Fix PPO bug
parent 6dc285d3f8
commit d364a1e4df
@@ -1,14 +1,16 @@
from DDPG import DDPG_agent
from datetime import datetime
from utils import str2bool, evaluate_policy
import gymnasium as gym
import os
import shutil
import argparse
import torch
# fmt: off
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import PartitionMazeEnv
from utils import str2bool, evaluate_policy
from datetime import datetime
from DDPG import DDPG_agent
# fmt: on

'''Hyperparameter Setting'''
parser = argparse.ArgumentParser()
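Note on the sys.path line above: appending the grandparent directory of the running script puts the repository root on the import path, which is what lets "from env import PartitionMazeEnv" resolve when the training script lives in an algorithm subfolder. A minimal standalone sketch of the same idea; the folder layout named in the comment is an assumption for illustration, not taken from this commit:

# assumed layout: repo_root/env.py and repo_root/some_algo/main.py (illustrative names)
import os
import sys

repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if repo_root not in sys.path:       # guard against appending the same path twice
    sys.path.append(repo_root)

from env import PartitionMazeEnv    # resolvable once repo_root is on sys.path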
@@ -133,8 +133,8 @@ class PPO_agent(object):
        self.dw_hoder[idx] = dw

    def save(self,EnvName, timestep):
-        torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
-        torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
+        torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
+        torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))

    def load(self,EnvName, timestep):
        self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
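One practical caveat with the new save paths: torch.save does not create missing directories, so ./weights has to exist before save() writes the first checkpoint, just as ./model did before. A hedged sketch of a guard the calling script could use; the helper name ensure_save_dir is mine, not code from this commit:

import os

def ensure_save_dir(path="./weights"):
    # torch.save will not create the directory itself, so make it before the first checkpoint
    os.makedirs(path, exist_ok=True)

ensure_save_dir()   # call once before agent.save(...) writes the first .pth file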
@@ -1,18 +1,20 @@
from datetime import datetime
import os
import shutil
import argparse
import torch
import gymnasium as gym
from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy
from PPO import PPO_agent
# fmt: off
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import PartitionMazeEnv
# fmt: on

'''Hyperparameter Setting'''
parser = argparse.ArgumentParser()
-parser.add_argument('--dvc', type=str, default='',
+parser.add_argument('--dvc', type=str, default='cpu',
                    help='running device: cuda or cpu')
parser.add_argument('--EnvIdex', type=int, default=0,
                    help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
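The --dvc default moves from '' to 'cpu', so the script runs out of the box on machines without CUDA. If GPU use should still be picked up automatically when one is present, a common guard is sketched below; the variable names are placeholders and not code from this commit:

import torch

dvc = 'cuda' if torch.cuda.is_available() else 'cpu'   # use the GPU only when one is actually visible
device = torch.device(dvc)                              # pass this (or the string) on to the agent
print('running on', device)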
@@ -79,9 +81,8 @@ def main():
    opt.state_dim = env.observation_space.shape[0]
    opt.action_dim = env.action_space.shape[0]
    opt.max_action = float(env.action_space.high[0])
-   opt.max_steps = env._max_episode_steps
    print('Env:', EnvName[opt.EnvIdex], ' state_dim:', opt.state_dim, ' action_dim:', opt.action_dim,
-         ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0], 'max_steps', opt.max_steps)
+         ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0])

    # Seed Everything
    env_seed = opt.seed
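The dropped opt.max_steps line read env._max_episode_steps, an attribute that is typically set by gymnasium's TimeLimit wrapper around registered environments; a hand-written environment such as PartitionMazeEnv may not expose it, so the read would fail. If a step budget is still wanted, a defensive sketch is below; the helper name and the fallback value 1000 are assumptions:

def episode_step_budget(env, fallback=1000):
    # use the TimeLimit wrapper's cap when the env provides one, otherwise a fixed budget
    return getattr(env, "_max_episode_steps", fallback)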
@@ -121,7 +122,7 @@ def main():
    traj_lenth, total_steps = 0, 0
    while total_steps < opt.Max_train_steps:
        # Do not use opt.seed directly, or it can overfit to opt.seed
-       s, info = env.reset(seed=env_seed)
+       s = env.reset(seed=env_seed)
        env_seed += 1
        done = False
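The reset call changes because gymnasium's standard API returns an (observation, info) pair, while this custom PartitionMazeEnv apparently returns only the observation, so unpacking two values would raise a ValueError. A hedged compatibility sketch that accepts either convention; the helper name is mine, not part of the commit:

def reset_obs(env, **kwargs):
    # handles both return conventions: gymnasium's (obs, info) tuple and a bare obs
    out = env.reset(**kwargs)
    if isinstance(out, tuple) and len(out) == 2:
        obs, _info = out
        return obs
    return out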
@@ -137,7 +137,7 @@ def Reward_adapter(r, EnvIdex):
def evaluate_policy(env, agent, max_action, turns):
    total_scores = 0
    for j in range(turns):
-       s, info = env.reset()
+       s = env.reset()
        done = False
        action_series = []
        while not done:
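For context, the evaluation routine around this hunk follows the usual pattern of rolling out the policy for a few episodes and averaging the return. A minimal sketch consistent with the lines shown; the agent.select_action call, its deterministic flag, and the 4-tuple step return are assumptions about the rest of the file, not code taken from this commit:

def evaluate_policy(env, agent, max_action, turns=3):
    total_scores = 0
    for _ in range(turns):
        s = env.reset()                                      # this env returns only the observation
        done = False
        while not done:
            a = agent.select_action(s, deterministic=True)   # assumed agent API
            s, r, done, info = env.step(a * max_action)      # assumed step return convention
            total_scores += r
    return total_scores / turns                              # average episodic return over the rollouts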