Modify DDPG

parent ab51727253
commit 75e5237272

.gitignore (vendored): 1 line changed
@@ -12,6 +12,7 @@ weights/
 solutions/
 PPO_preTrained/
 PPO_logs/
+logs/
 
 # Distribution / packaging
 .Python
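
The new logs/ entry keeps the relocated TensorBoard output (see the writepath hunk further down) untracked. As a quick check, git's own matcher can confirm which rule applies:

    git check-ignore -v logs/    # prints the .gitignore line that matches
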
@@ -65,12 +65,12 @@ class DDPG_agent():
             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
 
     def save(self,EnvName, timestep):
-        torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
-        torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
+        torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
+        torch.save(self.q_critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))
 
     def load(self,EnvName, timestep):
-        self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
-        self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
+        self.actor.load_state_dict(torch.load("./weights/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
+        self.q_critic.load_state_dict(torch.load("./weights/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
 
 
 class ReplayBuffer():
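
Note that save() writes into ./weights/ but does not create the directory; main() does that at startup (last hunk below). A minimal sketch of the round trip, assuming an already-constructed agent and that BrifEnvName[0] is 'PartitionMaze_DDPG' (both names are taken from elsewhere in this commit, not from this hunk):

    import os

    os.makedirs("./weights", exist_ok=True)  # guard: save()/load() assume the dir exists
    agent.save("PartitionMaze_DDPG", 100)    # writes ./weights/PartitionMaze_DDPG_actor100.pth (and the critic file)
    agent.load("PartitionMaze_DDPG", 100)    # reloads both nets onto self.dvc via map_location
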
@@ -13,14 +13,14 @@ from env import PartitionMazeEnv
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
 parser.add_argument('--dvc', type=str, default='cpu', help='running device: cuda or cpu')
-parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
+parser.add_argument('--EnvIdex', type=int, default=0, help='PartitionMaze_DDPG, PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
 parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
 parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
 parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
 parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
 
-parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--Max_train_steps', type=int, default=5e6, help='Max training steps')
+parser.add_argument('--seed', type=int, default=42, help='random seed')
+parser.add_argument('--Max_train_steps', type=int, default=5e8, help='Max training steps')
 parser.add_argument('--save_interval', type=int, default=1e5, help='Model saving interval, in steps.')
 parser.add_argument('--eval_interval', type=int, default=2e3, help='Model evaluating interval, in steps.')
 
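
With these defaults a fresh run seeds with 42 and trains for up to 5e8 steps. A hypothetical invocation (the script name main.py is an assumption; every flag comes from the parser above):

    python main.py --dvc cpu --EnvIdex 0 --write True
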
@@ -64,13 +64,13 @@ def main():
         from torch.utils.tensorboard import SummaryWriter
         timenow = str(datetime.now())[0:-10]
         timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
-        writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
+        writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
         if os.path.exists(writepath): shutil.rmtree(writepath)
         writer = SummaryWriter(log_dir=writepath)
 
 
     # Build DRL model
-    if not os.path.exists('model'): os.mkdir('model')
+    if not os.path.exists('weights'): os.mkdir('weights')
     agent = DDPG_agent(**vars(opt)) # var: transfer argparse to dictionary
     if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
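
The timestamp slicing above is terse; a minimal sketch of what it produces, assuming microseconds are present in str(datetime.now()) and that BrifEnvName[opt.EnvIdex] is 'PartitionMaze_DDPG':

    from datetime import datetime

    t = str(datetime.now())[0:-10]         # '2024-03-05 14:27'  (drops ':SS.ffffff')
    t = ' ' + t[0:13] + '_' + t[-2::]      # ' 2024-03-05 14_27' (HH:MM colon becomes an underscore)
    print('logs/PartitionMaze_DDPG' + t)   # logs/PartitionMaze_DDPG 2024-03-05 14_27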