from utils import Actor, Double_Q_Critic
import torch.nn.functional as F
import numpy as np
import torch
import copy


class SAC_countinuous():
    def __init__(self, **kwargs):
        # Initialize hyperparameters for the agent, e.g. self.gamma = opt.gamma, self.lambd = opt.lambd, ...
        self.__dict__.update(kwargs)
        self.tau = 0.005

        self.actor = Actor(self.state_dim, self.action_dim,
                           (self.net_width, self.net_width)).to(self.dvc)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.a_lr)

        self.q_critic = Double_Q_Critic(
            self.state_dim, self.action_dim, (self.net_width, self.net_width)).to(self.dvc)
        self.q_critic_optimizer = torch.optim.Adam(
            self.q_critic.parameters(), lr=self.c_lr)
        self.q_critic_target = copy.deepcopy(self.q_critic)
        # Freeze target networks with respect to optimizers (only update via polyak averaging)
        for p in self.q_critic_target.parameters():
            p.requires_grad = False

        self.replay_buffer = ReplayBuffer(
            self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc)

        if self.adaptive_alpha:
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as given in the paper
            self.target_entropy = torch.tensor(
                -self.action_dim, dtype=float, requires_grad=True, device=self.dvc)
            # We learn log_alpha instead of alpha to ensure alpha > 0
            self.log_alpha = torch.tensor(
                np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)

    def select_action(self, state, deterministic):
        # Only used when interacting with the env
        with torch.no_grad():
            state = torch.FloatTensor(state[np.newaxis, :]).to(self.dvc)
            a, _ = self.actor(state, deterministic, with_logprob=False)
        return a.cpu().numpy()[0]

    def train(self):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)

        # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
        with torch.no_grad():
            a_next, log_pi_a_next = self.actor(
                s_next, deterministic=False, with_logprob=True)
            target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
            target_Q = torch.min(target_Q1, target_Q2)
            # dw (dead/win) masks the bootstrap term at terminal states
            target_Q = r + (~dw) * self.gamma * \
                (target_Q - self.alpha * log_pi_a_next)

        # Get current Q estimates
        current_Q1, current_Q2 = self.q_critic(s, a)

        q_loss = F.mse_loss(current_Q1, target_Q) + \
            F.mse_loss(current_Q2, target_Q)
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()

        # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
        # Freeze the critic so we don't waste compute on its gradients while updating the actor
        for params in self.q_critic.parameters():
            params.requires_grad = False

        a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
        current_Q1, current_Q2 = self.q_critic(s, a)
        Q = torch.min(current_Q1, current_Q2)

        a_loss = (self.alpha * log_pi_a - Q).mean()
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        self.actor_optimizer.step()

        for params in self.q_critic.parameters():
            params.requires_grad = True

        # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
        if self.adaptive_alpha:
            # We learn log_alpha instead of alpha to ensure alpha > 0
            alpha_loss = -(self.log_alpha *
                           (log_pi_a + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()

        # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
        # Polyak averaging: target <- tau * online + (1 - tau) * target
        for param, target_param in zip(self.q_critic.parameters(),
                                       self.q_critic_target.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, EnvName, timestep):
        torch.save(self.actor.state_dict(),
                   "./model/{}_actor{}.pth".format(EnvName, timestep))
        torch.save(self.q_critic.state_dict(),
                   "./model/{}_q_critic{}.pth".format(EnvName, timestep))

    def load(self, EnvName, timestep):
        self.actor.load_state_dict(torch.load(
            "./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
        self.q_critic.load_state_dict(torch.load(
            "./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))


class ReplayBuffer():
    def __init__(self, state_dim, action_dim, max_size, dvc):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0

        self.s = torch.zeros((max_size, state_dim),
                             dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, action_dim),
                             dtype=torch.float, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros(
            (max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)

    def add(self, s, a, r, s_next, dw):
        # Add one transition at a time
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = torch.from_numpy(a).to(
            self.dvc)  # Note that a is a numpy.array
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw

        # Wrap around and overwrite the oldest data once the buffer is full
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
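

# ----------------------------- ↓↓↓↓↓ Usage sketch ↓↓↓↓↓ ------------------------------#
# Minimal, hypothetical usage sketch, not part of the agent itself. It assumes a
# Gymnasium-style env, that utils.Actor squashes actions to [-1, 1], and that the keyword
# arguments below (state_dim, action_dim, net_width, dvc, a_lr, c_lr, gamma, batch_size,
# alpha, adaptive_alpha) match what this agent reads from **kwargs; the env name,
# hyperparameter values, and warm-up length are illustrative only.
if __name__ == "__main__":
    import gymnasium as gym  # Assumed dependency for this sketch

    env = gym.make("Pendulum-v1")
    max_action = float(env.action_space.high[0])  # Assumed symmetric action bounds

    agent = SAC_countinuous(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        net_width=256,
        dvc=torch.device("cpu"),
        a_lr=3e-4,
        c_lr=3e-4,
        gamma=0.99,
        batch_size=256,
        alpha=0.12,
        adaptive_alpha=True,
    )

    warmup_steps = 1000  # Collect random transitions before training (illustrative value)
    s, _ = env.reset()
    for t in range(10000):
        if t < warmup_steps:
            a = env.action_space.sample() / max_action  # Random action, rescaled to [-1, 1]
        else:
            a = agent.select_action(s, deterministic=False)
        # Rescale the squashed action back to the env's bounds before stepping
        s_next, r, terminated, truncated, _ = env.step(a * max_action)
        agent.replay_buffer.add(s, a, r, s_next, terminated)
        s = s_next
        if t >= warmup_steps:
            agent.train()
        if terminated or truncated:
            s, _ = env.reset()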