import copy

import numpy as np
import torch
import torch.nn.functional as F

from utils import Actor, Double_Q_Critic


class SAC_countinuous():
    def __init__(self, **kwargs):
        # Copy all hyperparameters passed via kwargs onto the agent,
        # equivalent to writing self.gamma = opt.gamma, self.lambd = opt.lambd, ...
        self.__dict__.update(kwargs)
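        # Expected keys in kwargs (inferred from the attributes used below; not an exhaustive contract):
        # state_dim, action_dim, net_width, dvc, a_lr, c_lr, gamma, alpha, adaptive_alpha, batch_size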
        self.tau = 0.005

        self.actor = Actor(self.state_dim, self.action_dim,
                           (self.net_width, self.net_width)).to(self.dvc)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.a_lr)

        self.q_critic = Double_Q_Critic(self.state_dim, self.action_dim,
                                        (self.net_width, self.net_width)).to(self.dvc)
        self.q_critic_optimizer = torch.optim.Adam(self.q_critic.parameters(), lr=self.c_lr)
        self.q_critic_target = copy.deepcopy(self.q_critic)
        # Freeze target networks with respect to optimizers (only update via Polyak averaging)
        for p in self.q_critic_target.parameters():
            p.requires_grad = False

        self.replay_buffer = ReplayBuffer(
            self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc)

        if self.adaptive_alpha:
            # Target entropy = -dim(A) (e.g., -6 for HalfCheetah-v2), as suggested in the SAC paper
            self.target_entropy = torch.tensor(
                -self.action_dim, dtype=float, requires_grad=True, device=self.dvc)
            # We learn log_alpha instead of alpha to ensure alpha > 0
            self.log_alpha = torch.tensor(
                np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)

    def select_action(self, state, deterministic):
        # Only used when interacting with the env; expects a single (unbatched) state as a numpy array
        with torch.no_grad():
            state = torch.FloatTensor(state[np.newaxis, :]).to(self.dvc)
            a, _ = self.actor(state, deterministic, with_logprob=False)
            return a.cpu().numpy()[0]

    def train(self):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)

        # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
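        # Clipped double-Q soft Bellman backup computed below:
        #   y = r + gamma * (1 - dw) * ( min_i Q_target_i(s', a') - alpha * log pi(a'|s') ),  a' ~ pi(.|s')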
        with torch.no_grad():
            a_next, log_pi_a_next = self.actor(
                s_next, deterministic=False, with_logprob=True)
            target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
            target_Q = torch.min(target_Q1, target_Q2)
            # dw marks terminal (dead/done) transitions; ~dw zeroes the bootstrap term for them
            target_Q = r + (~dw) * self.gamma * \
                (target_Q - self.alpha * log_pi_a_next)

        # Get current Q estimates
        current_Q1, current_Q2 = self.q_critic(s, a)

        q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()

        # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
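        # Actor objective: minimize E_s[ alpha * log pi(a|s) - min_i Q_i(s, a) ] with a ~ pi(.|s),
        # i.e. maximize the entropy-regularized soft Q-value under the reparameterized policy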
        # Freeze the critics so we don't waste compute on their gradients while updating the actor
        for params in self.q_critic.parameters():
            params.requires_grad = False

        a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
        current_Q1, current_Q2 = self.q_critic(s, a)
        Q = torch.min(current_Q1, current_Q2)

        a_loss = (self.alpha * log_pi_a - Q).mean()
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        self.actor_optimizer.step()

        # Unfreeze the critics for the next Q update
        for params in self.q_critic.parameters():
            params.requires_grad = True

        # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
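        # Temperature update: minimize -log_alpha * (log pi(a|s) + target_entropy), a common surrogate for
        # J(alpha) = E_a~pi[ -alpha * (log pi(a|s) + target_entropy) ]; learning log_alpha keeps alpha > 0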
        if self.adaptive_alpha:
            # We learn log_alpha instead of alpha to ensure alpha > 0
            alpha_loss = -(self.log_alpha *
                           (log_pi_a + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()

        # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
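        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target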
        for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, EnvName, timestep):
        torch.save(self.actor.state_dict(),
                   "./model/{}_actor{}.pth".format(EnvName, timestep))
        torch.save(self.q_critic.state_dict(),
                   "./model/{}_q_critic{}.pth".format(EnvName, timestep))

    def load(self, EnvName, timestep):
        self.actor.load_state_dict(torch.load(
            "./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
        self.q_critic.load_state_dict(torch.load(
            "./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))


class ReplayBuffer():
    def __init__(self, state_dim, action_dim, max_size, dvc):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0

        self.s = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, action_dim), dtype=torch.float, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros((max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)

    def add(self, s, a, r, s_next, dw):
        # Store one transition (a single timestep) per call
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = torch.from_numpy(a).to(self.dvc)  # Note that a is a numpy.array
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw

        self.ptr = (self.ptr + 1) % self.max_size  # Overwrite from the beginning once the buffer is full
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
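

# Minimal usage sketch (illustrative only, not the repo's training script): shows how SAC_countinuous and
# ReplayBuffer are typically wired into a training loop. It assumes a Gymnasium-style continuous-control
# env ("Pendulum-v1") and that utils.Actor / utils.Double_Q_Critic accept the constructor arguments used
# above; the hyperparameter values below are illustrative, not tuned. Depending on how Actor squashes its
# output, actions may additionally need rescaling to env.action_space bounds.
if __name__ == "__main__":
    import os

    import gymnasium as gym

    os.makedirs("model", exist_ok=True)  # save() writes to ./model/

    env = gym.make("Pendulum-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    agent = SAC_countinuous(state_dim=state_dim, action_dim=action_dim, net_width=256, dvc="cpu",
                            a_lr=3e-4, c_lr=3e-4, gamma=0.99, alpha=0.12, adaptive_alpha=True,
                            batch_size=256)

    total_steps, warmup_steps = 10_000, 1_000
    s, _ = env.reset()
    for t in range(total_steps):
        if t < warmup_steps:
            a = env.action_space.sample()  # random exploration before learning starts
        else:
            a = agent.select_action(s, deterministic=False)
        s_next, r, terminated, truncated, _ = env.step(a)
        # dw should flag true terminal states only, not time-limit truncation
        agent.replay_buffer.add(s, a, r, s_next, terminated)
        s = s_next if not (terminated or truncated) else env.reset()[0]

        if t >= warmup_steps:
            agent.train()

    agent.save("Pendulum", total_steps)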