From 75e523727232381ef071202c36c081caa588c6e9 Mon Sep 17 00:00:00 2001 From: weixin_46229132 Date: Fri, 14 Mar 2025 16:06:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9DDPG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + {DDPG => DDPG_solver}/DDPG.py | 8 ++++---- {DDPG => DDPG_solver}/main.py | 10 +++++----- {DDPG => DDPG_solver}/utils.py | 0 4 files changed, 10 insertions(+), 9 deletions(-) rename {DDPG => DDPG_solver}/DDPG.py (89%) rename {DDPG => DDPG_solver}/main.py (92%) rename {DDPG => DDPG_solver}/utils.py (100%) diff --git a/.gitignore b/.gitignore index 4b3b308..a3041d4 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ weights/ solutions/ PPO_preTrained/ PPO_logs/ +logs/ # Distribution / packaging .Python diff --git a/DDPG/DDPG.py b/DDPG_solver/DDPG.py similarity index 89% rename from DDPG/DDPG.py rename to DDPG_solver/DDPG.py index 0a24aa1..63bc17a 100644 --- a/DDPG/DDPG.py +++ b/DDPG_solver/DDPG.py @@ -65,12 +65,12 @@ class DDPG_agent(): target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) def save(self,EnvName, timestep): - torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep)) - torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep)) + torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep)) + torch.save(self.q_critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep)) def load(self,EnvName, timestep): - self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc)) - self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc)) + self.actor.load_state_dict(torch.load("./weights/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc)) + self.q_critic.load_state_dict(torch.load("./weights/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc)) class ReplayBuffer(): diff --git a/DDPG/main.py b/DDPG_solver/main.py similarity index 92% rename from DDPG/main.py rename to DDPG_solver/main.py index 985b361..daa2c2c 100644 --- a/DDPG/main.py +++ b/DDPG_solver/main.py @@ -13,14 +13,14 @@ from env import PartitionMazeEnv '''Hyperparameter Setting''' parser = argparse.ArgumentParser() parser.add_argument('--dvc', type=str, default='cpu', help='running device: cuda or cpu') -parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3') +parser.add_argument('--EnvIdex', type=int, default=0, help='PartitionMaze_DDPG, PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3') parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training') parser.add_argument('--render', type=str2bool, default=False, help='Render or Not') parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not') parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load') -parser.add_argument('--seed', type=int, default=0, help='random seed') -parser.add_argument('--Max_train_steps', type=int, default=5e6, help='Max training steps') +parser.add_argument('--seed', type=int, default=42, help='random seed') +parser.add_argument('--Max_train_steps', type=int, default=5e8, help='Max training steps') parser.add_argument('--save_interval', type=int, default=1e5, help='Model saving interval, in steps.') parser.add_argument('--eval_interval', type=int, default=2e3, help='Model evaluating interval, in steps.') @@ -64,13 +64,13 @@ def main(): from torch.utils.tensorboard import SummaryWriter timenow = str(datetime.now())[0:-10] timenow = ' ' + timenow[0:13] + '_' + timenow[-2::] - writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow + writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow if os.path.exists(writepath): shutil.rmtree(writepath) writer = SummaryWriter(log_dir=writepath) # Build DRL model - if not os.path.exists('model'): os.mkdir('model') + if not os.path.exists('weights'): os.mkdir('weights') agent = DDPG_agent(**vars(opt)) # var: transfer argparse to dictionary if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex) diff --git a/DDPG/utils.py b/DDPG_solver/utils.py similarity index 100% rename from DDPG/utils.py rename to DDPG_solver/utils.py