修改DDPG
This commit is contained in:
parent
ab51727253
commit
75e5237272
1
.gitignore
vendored
1
.gitignore
vendored
@ -12,6 +12,7 @@ weights/
|
|||||||
solutions/
|
solutions/
|
||||||
PPO_preTrained/
|
PPO_preTrained/
|
||||||
PPO_logs/
|
PPO_logs/
|
||||||
|
logs/
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
.Python
|
.Python
|
||||||
|
@ -65,12 +65,12 @@ class DDPG_agent():
|
|||||||
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
||||||
|
|
||||||
def save(self,EnvName, timestep):
|
def save(self,EnvName, timestep):
|
||||||
torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
|
torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
|
||||||
torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
|
torch.save(self.q_critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))
|
||||||
|
|
||||||
def load(self,EnvName, timestep):
|
def load(self,EnvName, timestep):
|
||||||
self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
|
self.actor.load_state_dict(torch.load("./weights/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
|
||||||
self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
|
self.q_critic.load_state_dict(torch.load("./weights/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc))
|
||||||
|
|
||||||
|
|
||||||
class ReplayBuffer():
|
class ReplayBuffer():
|
@ -13,14 +13,14 @@ from env import PartitionMazeEnv
|
|||||||
'''Hyperparameter Setting'''
|
'''Hyperparameter Setting'''
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--dvc', type=str, default='cpu', help='running device: cuda or cpu')
|
parser.add_argument('--dvc', type=str, default='cpu', help='running device: cuda or cpu')
|
||||||
parser.add_argument('--EnvIdex', type=int, default=0, help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
|
parser.add_argument('--EnvIdex', type=int, default=0, help='PartitionMaze_DDPG, PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
|
||||||
parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
|
parser.add_argument('--write', type=str2bool, default=False, help='Use SummaryWriter to record the training')
|
||||||
parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
|
parser.add_argument('--render', type=str2bool, default=False, help='Render or Not')
|
||||||
parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
|
parser.add_argument('--Loadmodel', type=str2bool, default=False, help='Load pretrained model or Not')
|
||||||
parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
|
parser.add_argument('--ModelIdex', type=int, default=100, help='which model to load')
|
||||||
|
|
||||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
parser.add_argument('--seed', type=int, default=42, help='random seed')
|
||||||
parser.add_argument('--Max_train_steps', type=int, default=5e6, help='Max training steps')
|
parser.add_argument('--Max_train_steps', type=int, default=5e8, help='Max training steps')
|
||||||
parser.add_argument('--save_interval', type=int, default=1e5, help='Model saving interval, in steps.')
|
parser.add_argument('--save_interval', type=int, default=1e5, help='Model saving interval, in steps.')
|
||||||
parser.add_argument('--eval_interval', type=int, default=2e3, help='Model evaluating interval, in steps.')
|
parser.add_argument('--eval_interval', type=int, default=2e3, help='Model evaluating interval, in steps.')
|
||||||
|
|
||||||
@ -64,13 +64,13 @@ def main():
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
timenow = str(datetime.now())[0:-10]
|
timenow = str(datetime.now())[0:-10]
|
||||||
timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
|
timenow = ' ' + timenow[0:13] + '_' + timenow[-2::]
|
||||||
writepath = 'runs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
|
writepath = 'logs/{}'.format(BrifEnvName[opt.EnvIdex]) + timenow
|
||||||
if os.path.exists(writepath): shutil.rmtree(writepath)
|
if os.path.exists(writepath): shutil.rmtree(writepath)
|
||||||
writer = SummaryWriter(log_dir=writepath)
|
writer = SummaryWriter(log_dir=writepath)
|
||||||
|
|
||||||
|
|
||||||
# Build DRL model
|
# Build DRL model
|
||||||
if not os.path.exists('model'): os.mkdir('model')
|
if not os.path.exists('weights'): os.mkdir('weights')
|
||||||
agent = DDPG_agent(**vars(opt)) # var: transfer argparse to dictionary
|
agent = DDPG_agent(**vars(opt)) # var: transfer argparse to dictionary
|
||||||
if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
|
if opt.Loadmodel: agent.load(BrifEnvName[opt.EnvIdex], opt.ModelIdex)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user