From d364a1e4df7a351759c2c987120a92c67e42c1ba Mon Sep 17 00:00:00 2001
From: weixin_46229132
Date: Wed, 19 Mar 2025 15:23:55 +0800
Subject: [PATCH] Fix PPO bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 DDPG_solver/main.py     | 10 ++++++----
 PPO_Continuous/PPO.py   |  4 ++--
 PPO_Continuous/main.py  | 11 ++++++-----
 PPO_Continuous/utils.py |  2 +-
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/DDPG_solver/main.py b/DDPG_solver/main.py
index 86bccb2..c1fdd6e 100644
--- a/DDPG_solver/main.py
+++ b/DDPG_solver/main.py
@@ -1,14 +1,16 @@
+from DDPG import DDPG_agent
+from datetime import datetime
+from utils import str2bool, evaluate_policy
 import gymnasium as gym
-import os
 import shutil
 import argparse
 import torch
+# fmt: off
 import sys
+import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
-from utils import str2bool, evaluate_policy
-from datetime import datetime
-from DDPG import DDPG_agent
+# fmt: on
 
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
diff --git a/PPO_Continuous/PPO.py b/PPO_Continuous/PPO.py
index 5715854..276a1bf 100644
--- a/PPO_Continuous/PPO.py
+++ b/PPO_Continuous/PPO.py
@@ -133,8 +133,8 @@ class PPO_agent(object):
             self.dw_hoder[idx] = dw
 
     def save(self,EnvName, timestep):
-        torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
-        torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
+        torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
+        torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))
 
     def load(self,EnvName, timestep):
         self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
diff --git a/PPO_Continuous/main.py b/PPO_Continuous/main.py
index d6c2f4e..8aded0a 100644
--- a/PPO_Continuous/main.py
+++ b/PPO_Continuous/main.py
@@ -1,18 +1,20 @@
 from datetime import datetime
-import os
 import shutil
 import argparse
 import torch
 import gymnasium as gym
 from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy
 from PPO import PPO_agent
+# fmt: off
 import sys
+import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
+# fmt: on
 
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
-parser.add_argument('--dvc', type=str, default='',
+parser.add_argument('--dvc', type=str, default='cpu',
                     help='running device: cuda or cpu')
 parser.add_argument('--EnvIdex', type=int, default=0,
                     help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
@@ -79,9 +81,8 @@ def main():
     opt.state_dim = env.observation_space.shape[0]
     opt.action_dim = env.action_space.shape[0]
     opt.max_action = float(env.action_space.high[0])
-    opt.max_steps = env._max_episode_steps
     print('Env:', EnvName[opt.EnvIdex], ' state_dim:', opt.state_dim, ' action_dim:', opt.action_dim,
-          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0], 'max_steps', opt.max_steps)
+          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0])
 
     # Seed Everything
     env_seed = opt.seed
@@ -121,7 +122,7 @@ def main():
     traj_lenth, total_steps = 0, 0
     while total_steps < opt.Max_train_steps:
         # Do not use opt.seed directly, or it can overfit to opt.seed
-        s, info = env.reset(seed=env_seed)
+        s = env.reset(seed=env_seed)
         env_seed += 1
         done = False
 
diff --git a/PPO_Continuous/utils.py b/PPO_Continuous/utils.py
index 7d20a0f..58b339d 100644
--- a/PPO_Continuous/utils.py
+++ b/PPO_Continuous/utils.py
@@ -137,7 +137,7 @@ def Reward_adapter(r, EnvIdex):
 def evaluate_policy(env, agent, max_action, turns):
     total_scores = 0
     for j in range(turns):
-        s, info = env.reset()
+        s = env.reset()
         done = False
         action_series = []
         while not done: