From 5b468deb9d882e18dc92fda5e1cd2c3a3814bdd1 Mon Sep 17 00:00:00 2001
From: weixin_46229132
Date: Fri, 21 Mar 2025 16:04:42 +0800
Subject: [PATCH] SAC

---
 PPO_Continuous/PPO.py   |   4 +-
 PPO_Continuous/main.py  |  10 +--
 SAC_Continuous/SAC.py   | 144 +++++++++++++++++++++++++++++++++++++
 SAC_Continuous/main.py  | 155 ++++++++++++++++++++++++++++++++++++++++
 SAC_Continuous/utils.py | 133 ++++++++++++++++++++++++++++++++++
 env.py                  |  67 +++++++++--------
 6 files changed, 478 insertions(+), 35 deletions(-)
 create mode 100644 SAC_Continuous/SAC.py
 create mode 100644 SAC_Continuous/main.py
 create mode 100644 SAC_Continuous/utils.py

diff --git a/PPO_Continuous/PPO.py b/PPO_Continuous/PPO.py
index 276a1bf..4bd68ec 100644
--- a/PPO_Continuous/PPO.py
+++ b/PPO_Continuous/PPO.py
@@ -137,8 +137,8 @@ class PPO_agent(object):
         torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))
 
     def load(self,EnvName, timestep):
-        self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
-        self.critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
+        self.actor.load_state_dict(torch.load("./weights/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
+        self.critic.load_state_dict(torch.load("./weights/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
 
 
 
diff --git a/PPO_Continuous/main.py b/PPO_Continuous/main.py
index d7d8e87..4ba824e 100644
--- a/PPO_Continuous/main.py
+++ b/PPO_Continuous/main.py
@@ -17,18 +17,18 @@ parser = argparse.ArgumentParser()
 parser.add_argument('--dvc', type=str, default='cpu',
                     help='running device: cuda or cpu')
 parser.add_argument('--EnvIdex', type=int, default=0,
-                    help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
-parser.add_argument('--write', type=str2bool, default=False,
+                    help='PM_PPO_Con, PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
+parser.add_argument('--write', type=str2bool, default=True,
                     help='Use SummaryWriter to record the training')
 parser.add_argument('--render', type=str2bool, default=False,
                     help='Render or Not')
 parser.add_argument('--Loadmodel', type=str2bool, default=False,
                     help='Load pretrained model or Not')
-parser.add_argument('--ModelIdex', type=int, default=100,
+parser.add_argument('--ModelIdex', type=int, default=500,
                     help='which model to load')
 
 parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--T_horizon', type=int, default=200,
+parser.add_argument('--T_horizon', type=int, default=2000,
                     help='lenth of long trajectory')
 parser.add_argument('--Distribution', type=str, default='Beta',
                     help='Should be one of Beta ; GS_ms ; GS_m')
@@ -97,7 +97,7 @@ def main():
         from torch.utils.tensorboard import SummaryWriter
         timenow = str(datetime.now())[0:-10]
         timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
-        writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
+        writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
         if os.path.exists(writepath):
             shutil.rmtree(writepath)
         writer = SummaryWriter(log_dir=writepath)
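For reference, the map_location argument used in load() above is what makes these checkpoints portable across devices: a state_dict saved on a CUDA machine can be restored onto CPU (or the other way round) without editing the file. A minimal sketch outside the agent class, using a hypothetical checkpoint name:

import torch

dvc = torch.device('cpu')  # or torch.device('cuda') when available
# Hypothetical checkpoint written by PPO_agent.save() into ./weights/
ckpt_path = "./weights/PM_PPO_Con_actor500.pth"
state_dict = torch.load(ckpt_path, map_location=dvc)  # remap stored tensors onto dvc
# actor.load_state_dict(state_dict)  # then copy the tensors into the network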
opt.lambd, ..." + self.__dict__.update(kwargs) + self.tau = 0.005 + + self.actor = Actor(self.state_dim, self.action_dim, + (self.net_width, self.net_width)).to(self.dvc) + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=self.a_lr) + + self.q_critic = Double_Q_Critic( + self.state_dim, self.action_dim, (self.net_width, self.net_width)).to(self.dvc) + self.q_critic_optimizer = torch.optim.Adam( + self.q_critic.parameters(), lr=self.c_lr) + self.q_critic_target = copy.deepcopy(self.q_critic) + # Freeze target networks with respect to optimizers (only update via polyak averaging) + for p in self.q_critic_target.parameters(): + p.requires_grad = False + + self.replay_buffer = ReplayBuffer( + self.state_dim, self.action_dim, max_size=int(1e6), dvc=self.dvc) + + if self.adaptive_alpha: + # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper + self.target_entropy = torch.tensor(-self.action_dim, + dtype=float, requires_grad=True, device=self.dvc) + # We learn log_alpha instead of alpha to ensure alpha>0 + self.log_alpha = torch.tensor( + np.log(self.alpha), dtype=float, requires_grad=True, device=self.dvc) + self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=self.c_lr) + + def select_action(self, state, deterministic): + # only used when interact with the env + with torch.no_grad(): + state = torch.FloatTensor(state[np.newaxis, :]).to(self.dvc) + a, _ = self.actor(state, deterministic, with_logprob=False) + return a.cpu().numpy()[0] + + def train(self,): + s, a, r, s_next, dw = self.replay_buffer.sample(self.batch_size) + + # ----------------------------- ↓↓↓↓↓ Update Q Net ↓↓↓↓↓ ------------------------------# + with torch.no_grad(): + a_next, log_pi_a_next = self.actor( + s_next, deterministic=False, with_logprob=True) + target_Q1, target_Q2 = self.q_critic_target(s_next, a_next) + target_Q = torch.min(target_Q1, target_Q2) + # Dead or Done is tackled by Randombuffer + target_Q = r + (~dw) * self.gamma * \ + (target_Q - self.alpha * log_pi_a_next) + + # Get current Q estimates + current_Q1, current_Q2 = self.q_critic(s, a) + + q_loss = F.mse_loss(current_Q1, target_Q) + \ + F.mse_loss(current_Q2, target_Q) + self.q_critic_optimizer.zero_grad() + q_loss.backward() + self.q_critic_optimizer.step() + + # ----------------------------- ↓↓↓↓↓ Update Actor Net ↓↓↓↓↓ ------------------------------# + # Freeze critic so you don't waste computational effort computing gradients for them when update actor + for params in self.q_critic.parameters(): + params.requires_grad = False + + a, log_pi_a = self.actor(s, deterministic=False, with_logprob=True) + current_Q1, current_Q2 = self.q_critic(s, a) + Q = torch.min(current_Q1, current_Q2) + + a_loss = (self.alpha * log_pi_a - Q).mean() + self.actor_optimizer.zero_grad() + a_loss.backward() + self.actor_optimizer.step() + + for params in self.q_critic.parameters(): + params.requires_grad = True + + # ----------------------------- ↓↓↓↓↓ Update alpha ↓↓↓↓↓ ------------------------------# + if self.adaptive_alpha: + # We learn log_alpha instead of alpha to ensure alpha>0 + alpha_loss = -(self.log_alpha * (log_pi_a + + self.target_entropy).detach()).mean() + self.alpha_optim.zero_grad() + alpha_loss.backward() + self.alpha_optim.step() + self.alpha = self.log_alpha.exp() + + # ----------------------------- ↓↓↓↓↓ Update Target Net ↓↓↓↓↓ ------------------------------# + for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()): + target_param.data.copy_( + self.tau * 
param.data + (1 - self.tau) * target_param.data) + + def save(self, EnvName, timestep): + torch.save(self.actor.state_dict(), + "./model/{}_actor{}.pth".format(EnvName, timestep)) + torch.save(self.q_critic.state_dict(), + "./model/{}_q_critic{}.pth".format(EnvName, timestep)) + + def load(self, EnvName, timestep): + self.actor.load_state_dict(torch.load( + "./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc)) + self.q_critic.load_state_dict(torch.load( + "./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc)) + + +class ReplayBuffer(): + def __init__(self, state_dim, action_dim, max_size, dvc): + self.max_size = max_size + self.dvc = dvc + self.ptr = 0 + self.size = 0 + + self.s = torch.zeros((max_size, state_dim), + dtype=torch.float, device=self.dvc) + self.a = torch.zeros((max_size, action_dim), + dtype=torch.float, device=self.dvc) + self.r = torch.zeros((max_size, 1), dtype=torch.float, device=self.dvc) + self.s_next = torch.zeros( + (max_size, state_dim), dtype=torch.float, device=self.dvc) + self.dw = torch.zeros((max_size, 1), dtype=torch.bool, device=self.dvc) + + def add(self, s, a, r, s_next, dw): + # 每次只放入一个时刻的数据 + self.s[self.ptr] = torch.from_numpy(s).to(self.dvc) + self.a[self.ptr] = torch.from_numpy(a).to( + self.dvc) # Note that a is numpy.array + self.r[self.ptr] = r + self.s_next[self.ptr] = torch.from_numpy(s_next).to(self.dvc) + self.dw[self.ptr] = dw + + self.ptr = (self.ptr + 1) % self.max_size # 存满了又重头开始存 + self.size = min(self.size + 1, self.max_size) + + def sample(self, batch_size): + ind = torch.randint(0, self.size, device=self.dvc, size=(batch_size,)) + return self.s[ind], self.a[ind], self.r[ind], self.s_next[ind], self.dw[ind] diff --git a/SAC_Continuous/main.py b/SAC_Continuous/main.py new file mode 100644 index 0000000..6e9ad70 --- /dev/null +++ b/SAC_Continuous/main.py @@ -0,0 +1,155 @@ +from utils import str2bool, evaluate_policy, Action_adapter, Action_adapter_reverse, Reward_adapter +from datetime import datetime +from SAC import SAC_countinuous +import gymnasium as gym +import os +import shutil +import argparse +import torch +# fmt: off +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from env import PartitionMazeEnv +# fmt: on + +'''Hyperparameter Setting''' +parser = argparse.ArgumentParser() +parser.add_argument('--dvc', type=str, default='cpu', + help='running device: cuda or cpu') +parser.add_argument('--EnvIdex', type=int, default=0, + help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3') +parser.add_argument('--write', type=str2bool, default=True, + help='Use SummaryWriter to record the training') +parser.add_argument('--render', type=str2bool, + default=False, help='Render or Not') +parser.add_argument('--Loadmodel', type=str2bool, + default=False, help='Load pretrained model or Not') +parser.add_argument('--ModelIdex', type=int, default=100, + help='which model to load') + +parser.add_argument('--seed', type=int, default=0, help='random seed') +parser.add_argument('--Max_train_steps', type=int, + default=int(5e8), help='Max training steps') +parser.add_argument('--save_interval', type=int, + default=int(5e5), help='Model saving interval, in steps.') +parser.add_argument('--eval_interval', type=int, default=int(5e3), + help='Model evaluating interval, in steps.') +parser.add_argument('--update_every', type=int, default=50, + help='Training Fraquency, in stpes') + +parser.add_argument('--gamma', type=float, default=0.99, + help='Discounted Factor') 
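The last loop of train() above is the soft (polyak) target update, theta_target <- tau * theta + (1 - tau) * theta_target, i.e. an exponential moving average of the online critic. A small self-contained sketch of the same rule on throwaway networks (the names here are illustrative only):

import copy
import torch

tau = 0.005
q_net = torch.nn.Linear(4, 1)      # stand-in for the online critic
q_target = copy.deepcopy(q_net)    # target starts as an exact copy

with torch.no_grad():              # the target is never touched by an optimizer
    for p, p_targ in zip(q_net.parameters(), q_target.parameters()):
        p_targ.data.copy_(tau * p.data + (1 - tau) * p_targ.data)

With tau = 0.005 the target keeps roughly 99.5% of its previous value per update, which is what keeps the TD target r + gamma * (1 - dw) * (min(Q1', Q2') - alpha * log_pi) moving slowly.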
diff --git a/SAC_Continuous/main.py b/SAC_Continuous/main.py
new file mode 100644
index 0000000..6e9ad70
--- /dev/null
+++ b/SAC_Continuous/main.py
@@ -0,0 +1,155 @@
+from utils import str2bool, evaluate_policy, Action_adapter, Action_adapter_reverse, Reward_adapter
+from datetime import datetime
+from SAC import SAC_countinuous
+import gymnasium as gym
+import os
+import shutil
+import argparse
+import torch
+# fmt: off
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from env import PartitionMazeEnv
+# fmt: on
+
+'''Hyperparameter Setting'''
+parser = argparse.ArgumentParser()
+parser.add_argument('--dvc', type=str, default='cpu',
+                    help='running device: cuda or cpu')
+parser.add_argument('--EnvIdex', type=int, default=0,
+                    help='PM_SAC_Con, PV1, LLdV2, Humanv4, HCv4, BWv3, BWHv3')
+parser.add_argument('--write', type=str2bool, default=True,
+                    help='Use SummaryWriter to record the training')
+parser.add_argument('--render', type=str2bool,
+                    default=False, help='Render or Not')
+parser.add_argument('--Loadmodel', type=str2bool,
+                    default=False, help='Load pretrained model or Not')
+parser.add_argument('--ModelIdex', type=int, default=100,
+                    help='which model to load')
+
+parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--Max_train_steps', type=int,
+                    default=int(5e8), help='Max training steps')
+parser.add_argument('--save_interval', type=int,
+                    default=int(5e5), help='Model saving interval, in steps.')
+parser.add_argument('--eval_interval', type=int, default=int(5e3),
+                    help='Model evaluation interval, in steps.')
+parser.add_argument('--update_every', type=int, default=50,
+                    help='Training frequency, in steps')
+
+parser.add_argument('--gamma', type=float, default=0.99,
+                    help='Discount factor')
+parser.add_argument('--net_width', type=int, default=256,
+                    help='Hidden net width, s_dim-net_width-net_width-a_dim')
+parser.add_argument('--a_lr', type=float, default=3e-4,
+                    help='Learning rate of actor')
+parser.add_argument('--c_lr', type=float, default=3e-4,
+                    help='Learning rate of critic')
+parser.add_argument('--batch_size', type=int, default=256,
+                    help='batch_size of training')
+parser.add_argument('--alpha', type=float, default=0.12,
+                    help='Entropy coefficient')
+parser.add_argument('--adaptive_alpha', type=str2bool,
+                    default=True, help='Use adaptive_alpha or Not')
+opt = parser.parse_args()
+opt.dvc = torch.device(opt.dvc)  # from str to torch.device
+print(opt)
+
+
+def main():
+    EnvName = ['PartitionMaze_SAC_Continuous', 'Pendulum-v1', 'LunarLanderContinuous-v2',
+               'Humanoid-v4', 'HalfCheetah-v4', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3']
+    BrifEnvName = ['PM_SAC_Con', 'PV1', 'LLdV2',
+                   'Humanv4', 'HCv4', 'BWv3', 'BWHv3']
+
+    # Build Env
+    # env = gym.make(EnvName[opt.EnvIdex],
+    #                render_mode="human" if opt.render else None)
+    # eval_env = gym.make(EnvName[opt.EnvIdex])
+    env = PartitionMazeEnv()
+    eval_env = PartitionMazeEnv()
+    opt.state_dim = env.observation_space.shape[0]
+    opt.action_dim = env.action_space.shape[0]
+    # remark: action space [-max, max]
+    opt.max_action = float(env.action_space.high[0])
+    opt.max_e_steps = 100
+    print(f'Env:{EnvName[opt.EnvIdex]} state_dim:{opt.state_dim} action_dim:{opt.action_dim} '
+          f'max_a:{opt.max_action} min_a:{env.action_space.low[0]}')
+
+    # Seed Everything
+    env_seed = opt.seed
+    torch.manual_seed(opt.seed)
+    torch.cuda.manual_seed(opt.seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    print("Random Seed: {}".format(opt.seed))
+
+    # Build SummaryWriter to record training curves
+    if opt.write:
+        from torch.utils.tensorboard import SummaryWriter
+        timenow = str(datetime.now())[0:-10]
+        timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
+        writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
+        if os.path.exists(writepath):
+            shutil.rmtree(writepath)
+        writer = SummaryWriter(log_dir=writepath)
+
+    # vars(): convert the argparse Namespace into a dictionary
+    agent = SAC_countinuous(**vars(opt))
+    if opt.Loadmodel:
+        agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
+
+    if opt.render:
+        while True:
+            score = evaluate_policy(env, agent, turns=1)
+            print('EnvName:', BrifEnvName[opt.EnvIdex], 'score:', score)
+    else:
+        total_steps = 0
+        while total_steps < opt.Max_train_steps:
+            # Do not use opt.seed directly, or it can overfit to opt.seed
+            s = env.reset(seed=env_seed)
+            env_seed += 1
+            done = False
+
+            '''Interact & train'''
+            while not done:
+                if total_steps < (5*opt.max_e_steps):
+                    a = env.action_space.sample()  # act∈[-max,max]
+                    # a = Action_adapter_reverse(act, opt.max_action)  # a∈[-1,1]
+                else:
+                    a = agent.select_action(s, deterministic=False)  # a∈[-1,1]
+                    a = Action_adapter(a, opt.max_action)  # act∈[-max,max]
+                s_next, r, dw, tr, info = env.step(
+                    a)  # dw: dead&win; tr: truncated
+                r = Reward_adapter(r, opt.EnvIdex)
+                done = (dw or tr)
+
+                agent.replay_buffer.add(s, a, r, s_next, dw)
+                s = s_next
+                total_steps += 1
+
+                '''train if it's time'''
+                # train 50 times every 50 steps rather than once per step. Better!
+                if (total_steps >= 2*opt.max_e_steps) and (total_steps % opt.update_every == 0):
+                    for j in range(opt.update_every):
+                        agent.train()
+
+                '''record & log'''
+                if total_steps % opt.eval_interval == 0:
+                    ep_r = evaluate_policy(eval_env, agent, turns=3)
+                    if opt.write:
+                        writer.add_scalar(
+                            'ep_r', ep_r, global_step=total_steps)
+                    print(
+                        f'EnvName:{BrifEnvName[opt.EnvIdex]}, Steps: {int(total_steps/1000)}k, Episode Reward:{ep_r}')
+
+                '''save model'''
+                if total_steps % opt.save_interval == 0:
+                    agent.save(BrifEnvName[opt.EnvIdex], int(total_steps/1000))
+        env.close()
+        eval_env.close()
+
+
+if __name__ == '__main__':
+    main()
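Before launching a long run, the loop above can be smoke-tested with a short random rollout, mirroring its warm-up branch (total_steps < 5*opt.max_e_steps). A test sketch, not part of the patch, assuming it is run from the repository root so env.py is importable and that PartitionMazeEnv keeps the reset/step signatures used in main.py:

from env import PartitionMazeEnv

env = PartitionMazeEnv()
s = env.reset(seed=0)                 # this custom env returns only the state
for t in range(20):
    a = env.action_space.sample()     # same warm-up policy as in main.py
    s, r, dw, tr, info = env.step(a)  # dw: dead & win, tr: truncated
    print(t, 'reward:', round(float(r), 3), 'done:', dw or tr)
    if dw or tr:
        s = env.reset(seed=t + 1)
env.close()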
diff --git a/SAC_Continuous/utils.py b/SAC_Continuous/utils.py
new file mode 100644
index 0000000..cc3cf62
--- /dev/null
+++ b/SAC_Continuous/utils.py
@@ -0,0 +1,133 @@
+import torch
+import argparse
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal
+
+
+def build_net(layer_shape, hidden_activation, output_activation):
+    '''Build net with for loop'''
+    layers = []
+    for j in range(len(layer_shape)-1):
+        act = hidden_activation if j < len(
+            layer_shape)-2 else output_activation
+        layers += [nn.Linear(layer_shape[j], layer_shape[j+1]), act()]
+    return nn.Sequential(*layers)
+
+
+class Actor(nn.Module):
+    def __init__(self, state_dim, action_dim, hid_shape, hidden_activation=nn.ReLU, output_activation=nn.ReLU):
+        super(Actor, self).__init__()
+        layers = [state_dim] + list(hid_shape)
+
+        self.a_net = build_net(layers, hidden_activation, output_activation)
+        self.mu_layer = nn.Linear(layers[-1], action_dim)
+        self.log_std_layer = nn.Linear(layers[-1], action_dim)
+
+        self.LOG_STD_MAX = 2
+        self.LOG_STD_MIN = -20
+
+    def forward(self, state, deterministic, with_logprob):
+        '''Network with Enforcing Action Bounds'''
+        net_out = self.a_net(state)
+        mu = self.mu_layer(net_out)
+        log_std = self.log_std_layer(net_out)
+        log_std = torch.clamp(log_std, self.LOG_STD_MIN,
+                              self.LOG_STD_MAX)  # NOTE: this hard clamp may arguably hinder learning
+        # we learn log_std rather than std, so that exp(log_std) is always > 0
+        std = torch.exp(log_std)
+        dist = Normal(mu, std)
+        if deterministic:
+            u = mu
+        else:
+            u = dist.rsample()
+
+        '''↓↓↓ Enforcing Action Bounds, see Page 16 of https://arxiv.org/pdf/1812.05905.pdf ↓↓↓'''
+        a = torch.tanh(u)
+        if with_logprob:
+            # Get probability density of logp_pi_a from probability density of u:
+            # logp_pi_a = (dist.log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=1, keepdim=True)
+            # Derived from the equation above. It avoids tanh(u), so there is less gradient vanishing and better numerical stability.
+            logp_pi_a = dist.log_prob(u).sum(axis=1, keepdim=True) - (
+                2 * (np.log(2) - u - F.softplus(-2 * u))).sum(axis=1, keepdim=True)
+        else:
+            logp_pi_a = None
+
+        return a, logp_pi_a
+
+
+class Double_Q_Critic(nn.Module):
+    def __init__(self, state_dim, action_dim, hid_shape):
+        super(Double_Q_Critic, self).__init__()
+        layers = [state_dim + action_dim] + list(hid_shape) + [1]
+
+        self.Q_1 = build_net(layers, nn.ReLU, nn.Identity)
+        self.Q_2 = build_net(layers, nn.ReLU, nn.Identity)
+
+    def forward(self, state, action):
+        sa = torch.cat([state, action], 1)
+        q1 = self.Q_1(sa)
+        q2 = self.Q_2(sa)
+        return q1, q2
+
+# reward engineering for better training
+
+
+def Reward_adapter(r, EnvIdex):
+    # For Pendulum-v0
+    if EnvIdex == 0:
+        r = (r + 8) / 8
+
+    # For LunarLander
+    elif EnvIdex == 1:
+        if r <= -100:
+            r = -10
+
+    # For BipedalWalker
+    elif EnvIdex == 4 or EnvIdex == 5:
+        if r <= -100:
+            r = -1
+    return r
+
+
+def Action_adapter(a, max_action):
+    # from [-1,1] to [0,1]
+    return (a + 1) / 2
+
+
+def Action_adapter_reverse(act, max_action):
+    # from [-max,max] to [-1,1]
+    return act/max_action
+
+
+def evaluate_policy(env, agent, turns=3):
+    total_scores = 0
+    for j in range(turns):
+        s = env.reset()
+        done = False
+        action_series = []
+        while not done:
+            # Take deterministic actions at test time
+            a = agent.select_action(s, deterministic=True)
+            a = Action_adapter(a, 1)
+            s_next, r, dw, tr, info = env.step(a)
+            done = (dw or tr)
+            action_series.append(a[0])
+            total_scores += r
+            s = s_next
+        print('action_series:', np.round(action_series, 3))
+        print('state:', s)
+    return total_scores/turns
+
+
+def str2bool(v):
+    '''transfer str to bool for argparse'''
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
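The softplus identity used in Actor.forward above can be checked numerically: for the tanh-squashed policy, log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), and this is the per-dimension correction subtracted from the Gaussian log-density. A quick verification sketch:

import numpy as np
import torch
import torch.nn.functional as F

u = torch.linspace(-3.0, 3.0, steps=13)
direct = torch.log(1 - torch.tanh(u).pow(2))       # naive form, unstable for large |u|
stable = 2 * (np.log(2) - u - F.softplus(-2 * u))  # form used in utils.py
print(torch.allclose(direct, stable, atol=1e-4))   # True

The stable form never evaluates tanh(u), so it keeps working where 1 - tanh(u)^2 would underflow to zero.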
diff --git a/env.py b/env.py
index 6765df9..b78dcfb 100644
--- a/env.py
+++ b/env.py
@@ -40,7 +40,7 @@ class PartitionMazeEnv(gym.Env):
         # 可能需要手动修改的超参数
         ##############################
         self.CUT_NUM = 4  # 横切一半,竖切一半
-        self.BASE_LINE = 4000  # 基准时间,通过greedy或者蒙特卡洛计算出来
+        self.BASE_LINE = 3500  # 基准时间,通过greedy或者蒙特卡洛计算出来
         self.MAX_STEPS = 50  # 迷宫走法步数上限
         self.phase = 0  # 阶段控制,0:区域划分阶段,1:迷宫初始化阶段,2:走迷宫阶段
 
@@ -87,11 +87,11 @@ class PartitionMazeEnv(gym.Env):
         self.car_pos = [(self.H / 2, self.W / 2) for _ in range(self.num_cars)]
         self.car_traj = [[] for _ in range(self.num_cars)]
         self.current_car_index = 0
-        
+
         # 状态:前 4 维为 partition_values,其余为区域访问状态(初始全0)
         max_regions = (self.CUT_NUM // 2 + 1) ** 2
         state = np.concatenate([
-            self.partition_values, 
+            self.partition_values,
             np.zeros(max_regions, dtype=np.float32)
         ])
         return state
@@ -109,7 +109,7 @@
 
         # 构造当前状态:前 partition_step 个为已决策值,其余为 0,再补 7 个 0
         state = np.concatenate([
-            self.partition_values, 
+            self.partition_values,
             np.zeros((self.CUT_NUM // 2 + 1) ** 2, dtype=np.float32)
         ])
 
@@ -159,21 +159,22 @@
                         break
 
             if not valid_partition:
-                reward = -10000
+                reward = -10
                 # 状态:前 4 维为 partition_values,其余为区域访问状态(初始全0)
                 max_regions = (self.CUT_NUM // 2 + 1) ** 2
                 state = np.concatenate([
-                    self.partition_values, 
+                    self.partition_values,
                     np.zeros(max_regions, dtype=np.float32)
                 ])
                 return state, reward, True, False, {}
            else:
                 # 进入阶段 1:初始化迷宫
                 self.phase = 1
-                reward = 100
-                
+                reward = 0.2
+
                 # 构建反向索引,方便后续计算
-                self.reverse_rectangles = {v['center']: k for k, v in self.rectangles.items()}
+                self.reverse_rectangles = {
+                    v['center']: k for k, v in self.rectangles.items()}
 
                 region_centers = [
                     (i, j, self.rectangles[(i, j)]['center'])
@@ -194,19 +195,21 @@
 
             # 进入阶段 2:走迷宫
             self.phase = 2
-            
+
             # 构造访问状态向量
             max_regions = (self.CUT_NUM // 2 + 1) ** 2
             visit_status = np.zeros(max_regions, dtype=np.float32)
-            
+
             # 将实际区域的访问状态填入向量
             for i in range(len(self.row_cuts) - 1):
                 for j in range(len(self.col_cuts) - 1):
                     idx = i * (len(self.col_cuts) - 1) + j
-                    visit_status[idx] = float(self.rectangles[(i, j)]['is_visited'])
+                    visit_status[idx] = float(
+                        self.rectangles[(i, j)]['is_visited'])
             for i in range(idx + 1, max_regions):
                 visit_status[i] = 100
-            state = np.concatenate([self.partition_values, visit_status])
+            state = np.concatenate(
+                [self.partition_values, visit_status])
             return state, reward, False, False, {}
 
         elif self.phase == 2:
@@ -214,7 +217,7 @@
             current_car = self.current_car_index
             # 查表,找出当前车辆所在的网格
             current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
-            
+
             reward = 0
 
             # 当前动作 a 为 1 维连续动作,映射到四个方向
@@ -235,7 +238,7 @@
             if move_dir == 'up':
                 if current_row > 0:
                     new_row = current_row - 1
-                else: # 错误的移动给一些惩罚?
+                else:  # 错误的移动给一些惩罚?
                     new_row = current_row
                     # reward -= 10
             elif move_dir == 'down':
@@ -258,31 +261,43 @@
                     # reward -= 10
 
             # 如果移动不合法,或者动作为stay,则保持原位置
+            # Check whether the car actually moved
+            car_moved = (new_row != current_row or new_col != current_col)
             # 更新车辆位置
             self.car_pos[current_car] = self.rectangles[(
                 new_row, new_col)]['center']
-            if new_row != current_row or new_col != current_col:
+            if car_moved:
                 self.car_traj[current_car].append((new_row, new_col))
+                # 更新访问标记:将新网格标记为已访问
+                self.rectangles[(new_row, new_col)]['is_visited'] = True
+
+            # Track whether each car moved during the current round
+            if self.current_car_index == 0:
+                # A new round begins: reset the movement flags
+                self.cars_moved = [False] * self.num_cars
+            self.cars_moved[current_car] = car_moved
+            # When a round ends, check whether none of the cars moved
+            if self.current_car_index == (self.num_cars - 1) and not any(self.cars_moved):
+                reward -= 0.01
+
             self.step_count += 1
             self.current_car_index = (
                 self.current_car_index + 1) % self.num_cars
 
-            # 更新访问标记:将新网格标记为已访问
-            self.rectangles[(new_row, new_col)]['is_visited'] = True
-
             # 观察状态
             # 构造访问状态向量
             max_regions = (self.CUT_NUM // 2 + 1) ** 2
             visit_status = np.zeros(max_regions, dtype=np.float32)
-            
+
             # 将实际区域的访问状态填入向量
             for i in range(len(self.row_cuts) - 1):
                 for j in range(len(self.col_cuts) - 1):
                     idx = i * (len(self.col_cuts) - 1) + j
-                    visit_status[idx] = float(self.rectangles[(i, j)]['is_visited'])
+                    visit_status[idx] = float(
+                        self.rectangles[(i, j)]['is_visited'])
             for i in range(idx + 1, max_regions):
                 visit_status[i] = 100
-            state = np.concatenate([self.partition_values, visit_status]) 
+            state = np.concatenate([self.partition_values, visit_status])
 
             # Episode 终止条件:所有网格均被访问或步数达到上限
             done = all([value['is_visited'] for _, value in self.rectangles.items()]) or (
@@ -293,11 +308,7 @@
                                  for idx in range(self.num_cars)])
                 # TODO 让奖励在baseline附近变化更剧烈
                 # reward = math.exp(-T / self.BASE_LINE) * 1000
-                reward = self.BASE_LINE / T * 1000
-                if T < self.BASE_LINE:
-                    reward *= 10
-                print(reward)
-                
+                reward += self.BASE_LINE / T
                 # if reward > self.BASE_LINE:
                 #     reward -= 200
 
@@ -305,7 +316,7 @@
                 # reward -= 10 * self.step_count
                 # TODO 动态调整baseline
             elif done and self.step_count >= self.MAX_STEPS:
-                reward += -1000
+                reward += -0.8
 
             return state, reward, done, False, {}
 
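With BASE_LINE = 3500, the terminal shaping term reward += self.BASE_LINE / T keeps the completion bonus on the same O(1) scale as the other rewards in env.py (+0.2 for a valid partition, -0.01 for a stalled round, -0.8 on timeout, -10 for an invalid partition). A small sketch of the resulting values for a few hypothetical completion times T:

BASE_LINE = 3500
for T in (3000, 3500, 4000, 7000):     # hypothetical episode completion times
    print(T, round(BASE_LINE / T, 3))  # 1.167, 1.0, 0.875, 0.5

Finishing faster than the baseline yields a value slightly above 1, while slower episodes taper smoothly toward zero.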