# HPCC2025/SAC_Continuous/SAC.py
from utils import Actor, Double_Q_Critic
import torch.nn.functional as F
import numpy as np
import torch
import copy
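
# ------------------------------------------------------------------------------
# Soft Actor-Critic (SAC) for continuous action spaces:
#   - squashed Gaussian Actor and twin Q critics (Double_Q_Critic from utils)
#   - target critics updated by Polyak averaging (tau = 0.005)
#   - optional adaptive entropy temperature alpha (target entropy = -dim(A))
#   - 1e6-transition replay buffer kept on the training device
# ------------------------------------------------------------------------------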


class SAC_countinuous():
    def __init__(self, **kwargs):
        # Copy every hyperparameter passed in (e.g. self.gamma = opt.gamma, self.a_lr = opt.a_lr, ...)
        self.__dict__.update(kwargs)
        self.tau = 0.005

        self.actor = Actor(self.state_dim, self.action_dim,
                           (self.net_width, self.net_width)).to(self.dvc)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.a_lr)

        self.q_critic = Double_Q_Critic(self.state_dim, self.action_dim,
                                        (self.net_width, self.net_width)).to(self.dvc)
        self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=self.c_lr)
        self.q_critic_target = copy.deepcopy(self.q_critic)
        # Freeze the target critics w.r.t. the optimizers; they are only updated via Polyak averaging.
        for p in self.q_critic_target.parameters():
            p.requires_grad = False

        self.replay_buffer = ReplayBuffer(self.state_dim, self.action_dim,
                                          max_size=int(1e6), dvc=self.dvc)

        if self.adaptive_alpha:
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as suggested in the SAC paper.
            self.target_entropy = torch.tensor(-self.action_dim, dtype=float,
                                               requires_grad=True, device=self.dvc)
            # Learn log_alpha instead of alpha to keep alpha > 0.
            self.log_alpha = torch.tensor(np.log(self.alpha), dtype=float,
                                          requires_grad=True, device=self.dvc)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)
    def select_action(self, state, deterministic):
        # Only used when interacting with the environment (training samples batches from the replay buffer).
        with torch.no_grad():
            state = torch.FloatTensor(state[np.newaxis, :]).to(self.dvc)
            a, _ = self.actor(state, deterministic, with_logprob=False)
        return a.cpu().numpy()[0]
    def train(self):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)

        # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
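        # Soft Bellman backup computed below (a' is sampled from the current policy):
        #   y = r + (1 - dw) * gamma * ( min_i Q_target_i(s', a') - alpha * log pi(a'|s') )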
        with torch.no_grad():
            a_next, log_pi_a_next = self.actor(s_next, deterministic=False, with_logprob=True)
            target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
            target_Q = torch.min(target_Q1, target_Q2)
            # dw (dead/win) marks true terminal states stored by the ReplayBuffer;
            # the bootstrap term is masked out for them.
            target_Q = r + (~dw) * self.gamma * (target_Q - self.alpha * log_pi_a_next)

        # Get current Q estimates
        current_Q1, current_Q2 = self.q_critic(s, a)
        q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()
        # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
        # Freeze the critics so no gradients are wasted on them while the actor is updated.
        for params in self.q_critic.parameters():
            params.requires_grad = False
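        # Actor objective minimized below (gradients flow through the sampled action a into the critics):
        #   L_pi = E[ alpha * log pi(a|s) - min_i Q_i(s, a) ]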
        a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
        current_Q1, current_Q2 = self.q_critic(s, a)
        Q = torch.min(current_Q1, current_Q2)

        a_loss = (self.alpha * log_pi_a - Q).mean()
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        self.actor_optimizer.step()

        for params in self.q_critic.parameters():
            params.requires_grad = True
        # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
        if self.adaptive_alpha:
            # Learn log_alpha instead of alpha to keep alpha > 0; the loss drives the policy
            # entropy towards target_entropy:  L_alpha = -E[ log_alpha * (log pi(a|s) + H_target) ]
            alpha_loss = -(self.log_alpha * (log_pi_a + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()
        # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target
        for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    def save(self, EnvName, timestep):
        # Checkpoints are written to ./model/, which is assumed to already exist.
        torch.save(self.actor.state_dict(),
                   "./model/{}_actor{}.pth".format(EnvName, timestep))
        torch.save(self.q_critic.state_dict(),
                   "./model/{}_q_critic{}.pth".format(EnvName, timestep))

    def load(self, EnvName, timestep):
        self.actor.load_state_dict(torch.load(
            "./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
        self.q_critic.load_state_dict(torch.load(
            "./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))


class ReplayBuffer():
    def __init__(self, state_dim, action_dim, max_size, dvc):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0

        self.s = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, action_dim), dtype=torch.float, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)
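        # Note: all buffers are pre-allocated directly on self.dvc, so sampling needs no
        # host-to-device copies, at the cost of keeping max_size transitions in device memory.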
    def add(self, s, a, r, s_next, dw):
        # Store a single transition per call (s, a, s_next arrive as numpy arrays).
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = torch.from_numpy(a).to(self.dvc)
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw

        self.ptr = (self.ptr + 1) % self.max_size  # wrap around and overwrite once the buffer is full
        self.size = min(self.size + 1, self.max_size)
    def sample(self, batch_size):
        # Uniformly sample indices from the filled part of the buffer.
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
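

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training script).
# It assumes a Gymnasium-style continuous-control environment and an Actor that
# outputs actions in [-1, 1]; every hyperparameter value below is a placeholder,
# only the kwarg names match what SAC_countinuous.__init__ actually reads.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import gymnasium as gym  # assumed dependency, used only for this sketch

    env = gym.make("Pendulum-v1")
    agent = SAC_countinuous(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        net_width=256,
        dvc=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        a_lr=3e-4,
        c_lr=3e-4,
        gamma=0.99,
        alpha=0.12,
        adaptive_alpha=True,
        batch_size=256,
    )

    warmup_steps = 1000
    s, _ = env.reset()
    for t in range(10000):
        if t < warmup_steps:
            # Warm up with uniform random actions in the Actor's assumed [-1, 1] range.
            a = np.random.uniform(-1.0, 1.0, size=env.action_space.shape[0]).astype(np.float32)
        else:
            a = agent.select_action(s, deterministic=False)

        # Rescale to the environment's native action range before stepping.
        s_next, r, terminated, truncated, _ = env.step(a * env.action_space.high)
        dw = terminated  # true terminal state ("dead/win"), not a time-limit truncation
        agent.replay_buffer.add(s, a, float(r), s_next, dw)
        s = s_next if not (terminated or truncated) else env.reset()[0]

        if t >= warmup_steps:
            agent.train()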