Commit 5b468deb9d (parent 67c7a9d6c7): SAC
@@ -137,8 +137,8 @@ class PPO_agent(object):
         torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))

     def load(self,EnvName, timestep):
-        self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
-        self.critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
+        self.actor.load_state_dict(torch.load("./weights/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
+        self.critic.load_state_dict(torch.load("./weights/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))


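The load paths now match the save paths under ./weights/. Note that torch.save does not create missing directories, so the first save still assumes ./weights/ exists; a small guard like the following (a hypothetical helper, not part of this commit) avoids a FileNotFoundError on a fresh checkout:

import os
import torch

def save_state_dict(module, path):
    # Create the checkpoint directory on first use, then write the state dict.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(module.state_dict(), path)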
@@ -17,18 +17,18 @@ parser = argparse.ArgumentParser()
 parser.add_argument('--dvc', type=str, default='cpu',
                     help='running device: cuda or cpu')
 parser.add_argument('--EnvIdex', type=int, default=0,
-                    help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
+                    help='PM_PPO_Con, PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
-parser.add_argument('--write', type=str2bool, default=False,
+parser.add_argument('--write', type=str2bool, default=True,
                     help='Use SummaryWriter to record the training')
 parser.add_argument('--render', type=str2bool,
                     default=False, help='Render or Not')
 parser.add_argument('--Loadmodel', type=str2bool,
                     default=False, help='Load pretrained model or Not')
-parser.add_argument('--ModelIdex', type=int, default=100,
+parser.add_argument('--ModelIdex', type=int, default=500,
                     help='which model to load')

 parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--T_horizon', type=int, default=200,
+parser.add_argument('--T_horizon', type=int, default=2000,
                     help='lenth of long trajectory')
 parser.add_argument('--Distribution', type=str, default='Beta',
                     help='Should be one of Beta ; GS_ms ; GS_m')
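Since --write now defaults to True, TensorBoard logging is on unless explicitly disabled. The str2bool converter from utils.py accepts the usual spellings, so an override can be tested programmatically like this (a hypothetical snippet against the parser defined above):

# Hypothetical programmatic override of the new defaults, bypassing the shell:
opt = parser.parse_args(['--write', 'False', '--T_horizon', '2000'])
assert opt.write is False and opt.T_horizon == 2000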
@@ -97,7 +97,7 @@ def main():
         from torch.utils.tensorboard import SummaryWriter
         timenow = str(datetime.now())[0:-10]
         timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
-        writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
+        writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
         if os.path.exists(writepath):
             shutil.rmtree(writepath)
         writer = SummaryWriter(log_dir=writepath)
SAC_Continuous/SAC.py  (new file, 144 lines)
@@ -0,0 +1,144 @@
from utils import Actor, Double_Q_Critic
import torch.nn.functional as F
import numpy as np
import torch
import copy


class SAC_countinuous():
    def __init__(self, **kwargs):
        # Init hyperparameters for agent, just like "self.gamma = opt.gamma, self.lambd = opt.lambd, ..."
        self.__dict__.update(kwargs)
        self.tau = 0.005

        self.actor = Actor(self.state_dim, self.action_dim,
                           (self.net_width, self.net_width)).to(self.dvc)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.a_lr)

        self.q_critic = Double_Q_Critic(
            self.state_dim, self.action_dim, (self.net_width, self.net_width)).to(self.dvc)
        self.q_critic_optimizer = torch.optim.Adam(
            self.q_critic.parameters(), lr=self.c_lr)
        self.q_critic_target = copy.deepcopy(self.q_critic)
        # Freeze target networks with respect to optimizers (only update via polyak averaging)
        for p in self.q_critic_target.parameters():
            p.requires_grad = False

        self.replay_buffer = ReplayBuffer(
            self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc)

        if self.adaptive_alpha:
            # Target Entropy = -dim(A) (e.g., -6 for HalfCheetah-v2) as given in the paper
            self.target_entropy = torch.tensor(-self.action_dim,
                                               dtype=float, requires_grad=True, device=self.dvc)
            # We learn log_alpha instead of alpha to ensure alpha > 0
            self.log_alpha = torch.tensor(
                np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr)

    def select_action(self, state, deterministic):
        # Only used when interacting with the env
        with torch.no_grad():
            state = torch.FloatTensor(state[np.newaxis, :]).to(self.dvc)
            a, _ = self.actor(state, deterministic, with_logprob=False)
        return a.cpu().numpy()[0]

    def train(self,):
        s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size)

        # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------#
        with torch.no_grad():
            a_next, log_pi_a_next = self.actor(
                s_next, deterministic=False, with_logprob=True)
            target_Q1, target_Q2 = self.q_critic_target(s_next, a_next)
            target_Q = torch.min(target_Q1, target_Q2)
            # Dead or Done is handled via the replay buffer
            target_Q = r + (~dw) * self.gamma * \
                (target_Q - self.alpha * log_pi_a_next)

        # Get current Q estimates
        current_Q1, current_Q2 = self.q_critic(s, a)

        q_loss = F.mse_loss(current_Q1, target_Q) + \
            F.mse_loss(current_Q2, target_Q)
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()

        # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------#
        # Freeze the critic so no effort is wasted computing its gradients during the actor update
        for params in self.q_critic.parameters():
            params.requires_grad = False

        a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True)
        current_Q1, current_Q2 = self.q_critic(s, a)
        Q = torch.min(current_Q1, current_Q2)

        a_loss = (self.alpha * log_pi_a - Q).mean()
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        self.actor_optimizer.step()

        for params in self.q_critic.parameters():
            params.requires_grad = True

        # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------#
        if self.adaptive_alpha:
            # We learn log_alpha instead of alpha to ensure alpha > 0
            alpha_loss = -(self.log_alpha * (log_pi_a +
                           self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()

        # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------#
        for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, EnvName, timestep):
        torch.save(self.actor.state_dict(),
                   "./model/{}_actor{}.pth".format(EnvName, timestep))
        torch.save(self.q_critic.state_dict(),
                   "./model/{}_q_critic{}.pth".format(EnvName, timestep))

    def load(self, EnvName, timestep):
        self.actor.load_state_dict(torch.load(
            "./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
        self.q_critic.load_state_dict(torch.load(
            "./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))


class ReplayBuffer():
    def __init__(self, state_dim, action_dim, max_size, dvc):
        self.max_size = max_size
        self.dvc = dvc
        self.ptr = 0
        self.size = 0

        self.s = torch.zeros((max_size, state_dim),
                             dtype=torch.float, device=self.dvc)
        self.a = torch.zeros((max_size, action_dim),
                             dtype=torch.float, device=self.dvc)
        self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc)
        self.s_next = torch.zeros(
            (max_size, state_dim), dtype=torch.float, device=self.dvc)
        self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc)

    def add(self, s, a, r, s_next, dw):
        # Only one transition is stored per call
        self.s[self.ptr] = torch.from_numpy(s).to(self.dvc)
        self.a[self.ptr] = torch.from_numpy(a).to(
            self.dvc)  # Note that a is numpy.array
        self.r[self.ptr] = r
        self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc)
        self.dw[self.ptr] = dw

        self.ptr = (self.ptr + 1) % self.max_size  # once full, overwrite from the start
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,))
        return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind]
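For orientation, here is a minimal construction of the agent above (a hypothetical standalone sketch; in this repo main.py passes the same fields via vars(opt)). Because __init__ copies arbitrary kwargs into self.__dict__, every field the class reads (state_dim, action_dim, net_width, dvc, a_lr, c_lr, gamma, batch_size, alpha, adaptive_alpha) must be supplied:

import torch
from SAC import SAC_countinuous

# Hypothetical standalone setup mirroring what main.py passes in via vars(opt).
agent = SAC_countinuous(
    state_dim=8, action_dim=2, net_width=256,
    dvc=torch.device('cpu'),
    a_lr=3e-4, c_lr=3e-4, gamma=0.99,
    batch_size=256, alpha=0.12, adaptive_alpha=True,
)
# agent.train() should only be called once the replay buffer holds at least
# batch_size transitions, since sample() draws indices from [0, self.size).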
SAC_Continuous/main.py  (new file, 155 lines)
@@ -0,0 +1,155 @@
from utils import str2bool, evaluate_policy, Action_adapter, Action_adapter_reverse, Reward_adapter
from datetime import datetime
from SAC import SAC_countinuous
import gymnasium as gym
import os
import shutil
import argparse
import torch
# fmt: off
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import PartitionMazeEnv
# fmt: on

'''Hyperparameter Setting'''
parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cpu',
                    help='running device: cuda or cpu')
parser.add_argument('--EnvIdex', type=int, default=0,
                    help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
parser.add_argument('--write', type=str2bool, default=True,
                    help='Use SummaryWriter to record the training')
parser.add_argument('--render', type=str2bool,
                    default=False, help='Render or Not')
parser.add_argument('--Loadmodel', type=str2bool,
                    default=False, help='Load pretrained model or Not')
parser.add_argument('--ModelIdex', type=int, default=100,
                    help='which model to load')

parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--Max_train_steps', type=int,
                    default=int(5e8), help='Max training steps')
parser.add_argument('--save_interval', type=int,
                    default=int(5e5), help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=int(5e3),
                    help='Model evaluating interval, in steps.')
parser.add_argument('--update_every', type=int, default=50,
                    help='Training frequency, in steps')

parser.add_argument('--gamma', type=float, default=0.99,
                    help='Discounted Factor')
parser.add_argument('--net_width', type=int, default=256,
                    help='Hidden net width, s_dim-400-300-a_dim')
parser.add_argument('--a_lr', type=float, default=3e-4,
                    help='Learning rate of actor')
parser.add_argument('--c_lr', type=float, default=3e-4,
                    help='Learning rate of critic')
parser.add_argument('--batch_size', type=int, default=256,
                    help='batch_size of training')
parser.add_argument('--alpha', type=float, default=0.12,
                    help='Entropy coefficient')
parser.add_argument('--adaptive_alpha', type=str2bool,
                    default=True, help='Use adaptive_alpha or Not')
opt = parser.parse_args()
opt.dvc = torch.device(opt.dvc)  # from str to torch.device
print(opt)


def main():
    EnvName = ['PartitionMaze_SAC_Continuous', 'Pendulum-v1', 'LunarLanderContinuous-v2',
               'Humanoid-v4', 'HalfCheetah-v4', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3']
    BrifEnvName = ['PM_SAC_Con', 'PV1', 'LLdV2',
                   'Humanv4', 'HCv4', 'BWv3', 'BWHv3']

    # Build Env
    # env = gym.make(EnvName[opt.EnvIdex],
    #                render_mode="human" if opt.render else None)
    # eval_env = gym.make(EnvName[opt.EnvIdex])
    env = PartitionMazeEnv()
    eval_env = PartitionMazeEnv()
    opt.state_dim = env.observation_space.shape[0]
    opt.action_dim = env.action_space.shape[0]
    # remark: action space [-max, max]
    opt.max_action = float(env.action_space.high[0])
    opt.max_e_steps = 100
    print(f'Env:{EnvName[opt.EnvIdex]} state_dim:{opt.state_dim} action_dim:{opt.action_dim} '
          f'max_a:{opt.max_action} min_a:{env.action_space.low[0]}')

    # Seed Everything
    env_seed = opt.seed
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Random Seed: {}".format(opt.seed))

    # Build SummaryWriter to record training curves
    if opt.write:
        from torch.utils.tensorboard import SummaryWriter
        timenow = str(datetime.now())[0:-10]
        timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
        writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
        if os.path.exists(writepath):
            shutil.rmtree(writepath)
        writer = SummaryWriter(log_dir=writepath)

    # vars: transfer argparse namespace to a dictionary
    agent = SAC_countinuous(**vars(opt))
    if opt.Loadmodel:
        agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)

    if opt.render:
        while True:
            score = evaluate_policy(env, agent, turns=1)
            print('EnvName:', BrifEnvName[opt.EnvIdex], 'score:', score)
    else:
        total_steps = 0
        while total_steps < opt.Max_train_steps:
            # Do not use opt.seed directly, or it can overfit to opt.seed
            s = env.reset(seed=env_seed)
            env_seed += 1
            done = False

            '''Interact & train'''
            while not done:
                if total_steps < (5*opt.max_e_steps):
                    a = env.action_space.sample()  # act∈[-max,max]
                    # a = Action_adapter_reverse(act, opt.max_action)  # a∈[-1,1]
                else:
                    a = agent.select_action(s, deterministic=False)  # a∈[-1,1]
                    a = Action_adapter(a, opt.max_action)  # act∈[-max,max]
                s_next, r, dw, tr, info = env.step(
                    a)  # dw: dead&win; tr: truncated
                r = Reward_adapter(r, opt.EnvIdex)
                done = (dw or tr)

                agent.replay_buffer.add(s, a, r, s_next, dw)
                s = s_next
                total_steps += 1

                '''train if it's time'''
                # train 50 times every 50 steps rather than 1 training per step. Better!
                if (total_steps >= 2*opt.max_e_steps) and (total_steps % opt.update_every == 0):
                    for j in range(opt.update_every):
                        agent.train()

                '''record & log'''
                if total_steps % opt.eval_interval == 0:
                    ep_r = evaluate_policy(eval_env, agent, turns=3)
                    if opt.write:
                        writer.add_scalar(
                            'ep_r', ep_r, global_step=total_steps)
                    print(
                        f'EnvName:{BrifEnvName[opt.EnvIdex]}, Steps: {int(total_steps/1000)}k, Episode Reward:{ep_r}')

                '''save model'''
                if total_steps % opt.save_interval == 0:
                    agent.save(BrifEnvName[opt.EnvIdex], int(total_steps/1000))
        env.close()
        eval_env.close()


if __name__ == '__main__':
    main()
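main.py swaps the gym.make call for the project's own PartitionMazeEnv. The training loop only relies on a small slice of the Gymnasium API, so any environment matching the outline below would plug in the same way (a hypothetical stand-in, not the real env.py; note that main.py unpacks reset() as a single observation rather than the usual (obs, info) tuple):

import gymnasium as gym

class MinimalEnv(gym.Env):
    """Hypothetical stand-in showing the interface main.py actually uses."""
    def __init__(self):
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(8,))
        self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,))

    def reset(self, seed=None):
        # main.py expects a single return value: s = env.reset(seed=...)
        return self.observation_space.sample()

    def step(self, action):
        s_next = self.observation_space.sample()
        reward, dw, truncated, info = 0.0, False, False, {}
        # 5-tuple, consumed as: s_next, r, dw, tr, info = env.step(a)
        return s_next, reward, dw, truncated, info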
SAC_Continuous/utils.py  (new file, 133 lines)
@@ -0,0 +1,133 @@
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


def build_net(layer_shape, hidden_activation, output_activation):
    '''Build net with a for loop'''
    layers = []
    for j in range(len(layer_shape)-1):
        act = hidden_activation if j < len(
            layer_shape)-2 else output_activation
        layers += [nn.Linear(layer_shape[j], layer_shape[j+1]), act()]
    return nn.Sequential(*layers)


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hid_shape, hidden_activation=nn.ReLU, output_activation=nn.ReLU):
        super(Actor, self).__init__()
        layers = [state_dim] + list(hid_shape)

        self.a_net = build_net(layers, hidden_activation, output_activation)
        self.mu_layer = nn.Linear(layers[-1], action_dim)
        self.log_std_layer = nn.Linear(layers[-1], action_dim)

        self.LOG_STD_MAX = 2
        self.LOG_STD_MIN = -20

    def forward(self, state, deterministic, with_logprob):
        '''Network with Enforcing Action Bounds'''
        net_out = self.a_net(state)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, self.LOG_STD_MIN,
                              self.LOG_STD_MAX)  # this clamp may not be ideal for learning
        # we learn log_std rather than std, so that exp(log_std) is always > 0
        std = torch.exp(log_std)
        dist = Normal(mu, std)
        if deterministic:
            u = mu
        else:
            u = dist.rsample()

        '''↓↓↓ Enforcing Action Bounds, see Page 16 of https://arxiv.org/pdf/1812.05905.pdf ↓↓↓'''
        a = torch.tanh(u)
        if with_logprob:
            # Get probability density of logp_pi_a from probability density of u:
            # logp_pi_a = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=1, keepdim=True)
            # Derived from the above equation. No a, thus no tanh(u), thus less gradient vanishing and more stable.
            logp_pi_a = dist.log_prob(u).sum(axis=1, keepdim=True) - (
                2 * (np.log(2) - u - F.softplus(-2 * u))).sum(axis=1, keepdim=True)
        else:
            logp_pi_a = None

        return a, logp_pi_a


class Double_Q_Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hid_shape):
        super(Double_Q_Critic, self).__init__()
        layers = [state_dim + action_dim] + list(hid_shape) + [1]

        self.Q_1 = build_net(layers, nn.ReLU, nn.Identity)
        self.Q_2 = build_net(layers, nn.ReLU, nn.Identity)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)
        q1 = self.Q_1(sa)
        q2 = self.Q_2(sa)
        return q1, q2


# reward engineering for better training
def Reward_adapter(r, EnvIdex):
    # For Pendulum-v0
    if EnvIdex == 0:
        r = (r + 8) / 8

    # For LunarLander
    elif EnvIdex == 1:
        if r <= -100:
            r = -10

    # For BipedalWalker
    elif EnvIdex == 4 or EnvIdex == 5:
        if r <= -100:
            r = -1
    return r


def Action_adapter(a, max_action):
    # from [-1,1] to [0,1]
    return (a + 1) / 2


def Action_adapter_reverse(act, max_action):
    # from [-max,max] to [-1,1]
    return act/max_action


def evaluate_policy(env, agent, turns=3):
    total_scores = 0
    for j in range(turns):
        s = env.reset()
        done = False
        action_series = []
        while not done:
            # Take deterministic actions at test time
            a = agent.select_action(s, deterministic=True)
            a = Action_adapter(a, 1)
            s_next, r, dw, tr, info = env.step(a)
            done = (dw or tr)
            action_series.append(a[0])
            total_scores += r
            s = s_next
        print('action_series:', np.round(action_series, 3))
        print('state:', s)
    return total_scores/turns


def str2bool(v):
    '''transfer str to bool for argparse'''
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'True', 'true', 'TRUE', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'False', 'false', 'FALSE', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
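The log-probability correction in Actor.forward uses the numerically stable form of the tanh change-of-variables term, log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u)), instead of the commented-out naive form. A quick numerical check of that identity (illustrative only):

import math
import torch
import torch.nn.functional as F

u = torch.linspace(-3.0, 3.0, steps=13)
direct = torch.log(1 - torch.tanh(u) ** 2)               # naive form from the commented-out line
stable = 2 * (math.log(2) - u - F.softplus(-2 * u))      # form used in utils.py
print(torch.allclose(direct, stable, atol=1e-4))         # True: both equal log(1 - tanh(u)^2)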
env.py  (45 changed lines)
@@ -40,7 +40,7 @@ class PartitionMazeEnv(gym.Env):
        # Hyperparameters that may need manual tuning
        ##############################
        self.CUT_NUM = 4   # half of the cuts horizontal, half vertical
-        self.BASE_LINE = 4000  # baseline time, computed via greedy or Monte Carlo
+        self.BASE_LINE = 3500  # baseline time, computed via greedy or Monte Carlo
        self.MAX_STEPS = 50  # upper bound on the number of maze-walking steps

        self.phase = 0  # phase control, 0: partitioning, 1: maze initialization, 2: maze walking
@@ -159,7 +159,7 @@ class PartitionMazeEnv(gym.Env):
                    break

            if not valid_partition:
-                reward = -10000
+                reward = -10
                # State: first 4 dims are partition_values, the rest are region visit flags (initially all 0)
                max_regions = (self.CUT_NUM // 2 + 1) ** 2
                state = np.concatenate([
@@ -170,10 +170,11 @@ class PartitionMazeEnv(gym.Env):
            else:
                # Enter phase 1: initialize the maze
                self.phase = 1
-                reward = 100
+                reward = 0.2

                # Build a reverse index for later lookups
-                self.reverse_rectangles = {v['center']: k for k, v in self.rectangles.items()}
+                self.reverse_rectangles = {
+                    v['center']: k for k, v in self.rectangles.items()}

                region_centers = [
                    (i, j, self.rectangles[(i, j)]['center'])
@@ -203,10 +204,12 @@ class PartitionMazeEnv(gym.Env):
                for i in range(len(self.row_cuts) - 1):
                    for j in range(len(self.col_cuts) - 1):
                        idx = i * (len(self.col_cuts) - 1) + j
-                        visit_status[idx] = float(self.rectangles[(i, j)]['is_visited'])
+                        visit_status[idx] = float(
+                            self.rectangles[(i, j)]['is_visited'])
                for i in range(idx + 1, max_regions):
                    visit_status[i] = 100
-                state = np.concatenate([self.partition_values, visit_status])
+                state = np.concatenate(
+                    [self.partition_values, visit_status])
                return state, reward, False, False, {}

        elif self.phase == 2:
@@ -258,18 +261,29 @@ class PartitionMazeEnv(gym.Env):
                    # reward -= 10
                # If the move is invalid, or the action is "stay", keep the current position

+                # Check whether the vehicle actually moved
+                car_moved = (new_row != current_row or new_col != current_col)
                # Update the vehicle position
                self.car_pos[current_car] = self.rectangles[(
                    new_row, new_col)]['center']
-                if new_row != current_row or new_col != current_col:
+                if car_moved:
                    self.car_traj[current_car].append((new_row, new_col))
+                    # Update the visit flag: mark the new cell as visited
+                    self.rectangles[(new_row, new_col)]['is_visited'] = True
+
+                # Record whether each vehicle moved during this round
+                if self.current_car_index == 0:
+                    # Start of a new round: reset the moved flags
+                    self.cars_moved = [False] * self.num_cars
+                self.cars_moved[current_car] = car_moved
+                # At the end of a round, check whether no vehicle moved at all
+                if self.current_car_index == (self.num_cars - 1) and not any(self.cars_moved):
+                    reward -= 0.01
+
                self.step_count += 1
                self.current_car_index = (
                    self.current_car_index + 1) % self.num_cars

-                # Update the visit flag: mark the new cell as visited
-                self.rectangles[(new_row, new_col)]['is_visited'] = True
-
                # Observe the state
                # Build the visit-status vector
                max_regions = (self.CUT_NUM // 2 + 1) ** 2
@@ -279,7 +293,8 @@ class PartitionMazeEnv(gym.Env):
            for i in range(len(self.row_cuts) - 1):
                for j in range(len(self.col_cuts) - 1):
                    idx = i * (len(self.col_cuts) - 1) + j
-                    visit_status[idx] = float(self.rectangles[(i, j)]['is_visited'])
+                    visit_status[idx] = float(
+                        self.rectangles[(i, j)]['is_visited'])
            for i in range(idx + 1, max_regions):
                visit_status[i] = 100
            state = np.concatenate([self.partition_values, visit_status])
@@ -293,11 +308,7 @@ class PartitionMazeEnv(gym.Env):
                for idx in range(self.num_cars)])
            # TODO: make the reward change more sharply around the baseline
            # reward = math.exp(-T / self.BASE_LINE) * 1000
-            reward = self.BASE_LINE / T * 1000
-            if T < self.BASE_LINE:
-                reward *= 10
-            print(reward)
-
+            reward += self.BASE_LINE / T

            # if reward > self.BASE_LINE:
            #     reward -= 200
@@ -305,7 +316,7 @@ class PartitionMazeEnv(gym.Env):
            # reward -= 10 * self.step_count
            # TODO: dynamically adjust the baseline
        elif done and self.step_count >= self.MAX_STEPS:
-            reward += -1000
+            reward += -0.8

        return state, reward, done, False, {}

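Taken together, the env.py changes rescale every reward term by roughly three orders of magnitude: the invalid-partition penalty drops from -10000 to -10, the phase-transition bonus from 100 to 0.2, the timeout penalty from -1000 to -0.8, and the completion reward becomes BASE_LINE / T instead of BASE_LINE / T * 1000 with an extra 10x bonus below the baseline. Keeping rewards near unit scale is a common choice for SAC, whose entropy term alpha * log_pi lives on that scale. A quick sanity check of the new completion term (assuming the new BASE_LINE of 3500):

BASE_LINE = 3500
for T in (3000, 3500, 7000):
    print(T, round(BASE_LINE / T, 3))  # 1.167, 1.0, 0.5 -- the reward shrinks as the tour time grows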