# HPCC2025/SAC_Continuous/SAC.py

from utils import Actor, Double_Q_Critic
import torch.nn.functional as F
import numpy as np
import torch
import copy
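
# NOTE (assumed interface, inferred from how they are called below; the actual
# definitions live in utils.py): Actor(state_dim, action_dim, hid_shape) is a
# stochastic policy whose forward(state, deterministic, with_logprob) returns
# (action, log_prob), and Double_Q_Critic(state_dim, action_dim, hid_shape)
# returns the twin estimates (Q1, Q2) from forward(state, action).

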
class SAC_countinuous():
    def __init__(self, **kwargs):
        # Init hyperparameters for the agent, e.g. "self.gamma = opt.gamma, self.lambd = opt.lambd, ..."
        self.__dict__.update(kwargs)
        self.tau = 0.005

        self.actor = Actor(self.state_dim, self.action_dim,
                           (self.net_width, self.net_width)).to(self.dvc)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.a_lr)

        self.q_critic = Double_Q_Critic(
            self.state_dim, self.action_dim, (self.net_width, self.net_width)).to(self.dvc)
        self.q_critic_optimizer = torch.optim.Adam(
            self.q_critic.parameters(), lr=self.c_lr)
        self.q_critic_target = copy.deepcopy(self.q_critic)
        # Freeze target networks with respect to optimizers (only updated via polyak averaging)
        for p in self.q_critic_target.parameters():
            p.requires_grad = False

        self.replay_buffer = ReplayBuffer(
            self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc)

        if self.adaptive_alpha:
            # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as given in the paper
            self.target_entropy = torch.tensor(-self.action_dim,
                                               dtype=float, requires_grad=True, device=self.dvc)
            # We learn log_alpha instead of alpha to ensure alpha > 0
            self.log_alpha = torch.tensor(
                np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)

    def select_action(self, state, deterministic):
        # Only used when interacting with the env (deterministic=True for evaluation, False for exploration)
        with torch.no_grad():
            state = torch.FloatTensor(state[np.newaxis, :]).to(self.dvc)
            a, _ = self.actor(state, deterministic, with_logprob=False)
        return a.cpu().numpy()[0]

    def train(self,):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)

        # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
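        # Soft Bellman target computed below:
        #   y = r + gamma * (1 - dw) * ( min_i Q'_i(s', a') - alpha * log pi(a'|s') ),  a' ~ pi(.|s')
        # i.e. clipped double-Q bootstrapping plus the entropy bonus of the SAC objective.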
        with torch.no_grad():
            a_next, log_pi_a_next = self.actor(
                s_next, deterministic=False, with_logprob=True)
            target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
            target_Q = torch.min(target_Q1, target_Q2)
            # dw (dead or done, i.e. a true terminal state) masks out the bootstrap term
            target_Q = r + (~dw) * self.gamma * \
                (target_Q - self.alpha * log_pi_a_next)

        # Get current Q estimates
        current_Q1, current_Q2 = self.q_critic(s, a)

        q_loss = F.mse_loss(current_Q1, target_Q) + \
            F.mse_loss(current_Q2, target_Q)
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()

        # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
        # Freeze the critic so we don't waste compute on its gradients while updating the actor
        for params in self.q_critic.parameters():
            params.requires_grad = False

        a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
        current_Q1, current_Q2 = self.q_critic(s, a)
        Q = torch.min(current_Q1, current_Q2)

        a_loss = (self.alpha * log_pi_a - Q).mean()
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        self.actor_optimizer.step()

        for params in self.q_critic.parameters():
            params.requires_grad = True

        # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
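        # Temperature update: minimize  -E[ log_alpha * (log pi(a|s) + target_entropy) ],
        # which raises alpha when the policy's entropy drops below target_entropy and lowers it otherwise.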
        if self.adaptive_alpha:
            # We learn log_alpha instead of alpha to ensure alpha > 0
            alpha_loss = -(self.log_alpha * (log_pi_a +
                           self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()

        # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
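        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target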
        for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, EnvName, timestep):
        torch.save(self.actor.state_dict(),
                   "./model/{}_actor{}.pth".format(EnvName, timestep))
        torch.save(self.q_critic.state_dict(),
                   "./model/{}_q_critic{}.pth".format(EnvName, timestep))

    def load(self, EnvName, timestep):
        self.actor.load_state_dict(torch.load(
            "./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
        self.q_critic.load_state_dict(torch.load(
            "./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))


class ReplayBuffer():
    def __init__(self, state_dim, action_dim, max_size, dvc):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0

        self.s = torch.zeros((max_size, state_dim),
                             dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, action_dim),
                             dtype=torch.float, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros(
            (max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)

    def add(self, s, a, r, s_next, dw):
        # Store one transition (a single timestep) per call
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = torch.from_numpy(a).to(
            self.dvc)  # Note that a is a numpy.array
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw

        self.ptr = (self.ptr + 1) % self.max_size  # Once full, overwrite from the beginning
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
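

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). It assumes
# a Gymnasium-style environment, that utils.Actor emits actions already scaled
# to the env's action range, and that the hyperparameter values below are only
# placeholders; the kwarg names match the attributes SAC_countinuous reads.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    agent = SAC_countinuous(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        net_width=256,
        dvc=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        a_lr=3e-4, c_lr=3e-4,
        gamma=0.99, alpha=0.12, adaptive_alpha=True,
        batch_size=256,
    )

    s, _ = env.reset()
    for t in range(10_000):
        if t < 1_000:   # warm-up phase: collect random transitions before training
            a = env.action_space.sample()
        else:
            a = agent.select_action(s, deterministic=False)
        s_next, r, terminated, truncated, _ = env.step(a)
        dw = terminated  # "dead or done": a true terminal state, not a time-limit truncation
        agent.replay_buffer.add(s, np.asarray(a, dtype=np.float32), r, s_next, dw)
        s = s_next if not (terminated or truncated) else env.reset()[0]
        if t >= 1_000:
            agent.train()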