修改DDPG

This commit is contained in:
weixin_46229132 2025-03-14 16:06:59 +08:00
parent ab51727253
commit 75e5237272
4 changed files with 10 additions and 9 deletions

1
.gitignore vendored
View File

@ -12,6 +12,7 @@ weights/
solutions/
PPO_preTrained/
PPO_logs/
logs/
# Distribution / packaging
.Python

View File

@ -65,12 +65,12 @@ class DDPG_agent():
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def save(self,EnvName, timestep):
torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
torch.save(self.q_critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))
def load(self,EnvName, timestep):
self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
self.actor.load_state_dict(torch.load("./weights/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
self.q_critic.load_state_dict(torch.load("./weights/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
class ReplayBuffer():

View File

@ -13,14 +13,14 @@ from env import PartitionMazeEnv
'''Hyperparameter Setting'''
parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cpu', help='running device: cuda or cpu')
parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
parser.add_argument('--EnvIdex', type=int, default=0, help='PartitionMaze_DDPG, PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--Max_train_steps', type=int, default=5e6, help='Max training steps')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--Max_train_steps', type=int, default=5e8, help='Max training steps')
parser.add_argument('--save_interval', type=int, default=1e5, help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=2e3, help='Model evaluating interval, in steps.')
@ -64,13 +64,13 @@ def main():
from torch.utils.tensorboard import SummaryWriter
timenow = str(datetime.now())[0:-10]
timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
if os.path.exists(writepath): shutil.rmtree(writepath)
writer = SummaryWriter(log_dir=writepath)
# Build DRL model
if not os.path.exists('model'): os.mkdir('model')
if not os.path.exists('weights'): os.mkdir('weights')
agent = DDPG_agent(**vars(opt)) # var: transfer argparse to dictionary
if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)