Fix PPO bug
parent 6dc285d3f8
commit d364a1e4df

@@ -1,14 +1,16 @@
+from DDPG import DDPG_agent
+from datetime import datetime
+from utils import str2bool, evaluate_policy
 import gymnasium as gym
-import os
 import shutil
 import argparse
 import torch
+# fmt: off
 import sys
+import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
-from utils import str2bool, evaluate_policy
-from datetime import datetime
-from DDPG import DDPG_agent
+# fmt: on

 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
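
The # fmt: off / # fmt: on pair fences the sys.path manipulation so that formatters which honour the markers (Black does; isort can be configured to) leave the block exactly as written and do not hoist the local import above the path tweak. A minimal sketch of the same pattern in isolation; env and PartitionMazeEnv are the project's own names taken from the diff:

    # fmt: off
    # Keep this block as written: the path tweak must run before the local
    # import below, which is why auto-formatting is switched off around it.
    import os
    import sys
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from env import PartitionMazeEnv  # only importable after the path tweak
    # fmt: on
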
@@ -133,8 +133,8 @@ class PPO_agent(object):
         self.dw_hoder[idx] = dw

     def save(self,EnvName, timestep):
-        torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
-        torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
+        torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
+        torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))

     def load(self,EnvName, timestep):
         self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
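
torch.save does not create missing parent folders, so the new ./weights directory has to exist before the first checkpoint is written. Below is a minimal sketch of guarding the call, assuming the directory is created lazily at save time; save_checkpoint and the example arguments are illustrative, not part of the commit:

    import os
    import torch

    def save_checkpoint(module: torch.nn.Module, path: str) -> None:
        # torch.save fails when the parent folder is missing, so create it
        # first; exist_ok=True makes repeated calls safe.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(module.state_dict(), path)

    # usage mirroring the new save path (arguments are placeholders):
    # save_checkpoint(agent.actor, "./weights/{}_actor{}.pth".format("PM", 1000))
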
@@ -1,18 +1,20 @@
 from datetime import datetime
-import os
 import shutil
 import argparse
 import torch
 import gymnasium as gym
 from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy
 from PPO import PPO_agent
+# fmt: off
 import sys
+import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
+# fmt: on

 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
-parser.add_argument('--dvc', type=str, default='',
+parser.add_argument('--dvc', type=str, default='cpu',
                     help='running device: cuda or cpu')
 parser.add_argument('--EnvIdex', type=int, default=0,
                     help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
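
The --dvc default changes from an empty string to 'cpu', so torch.device(opt.dvc) always receives a valid device string. A minimal sketch of resolving the flag defensively, assuming opt.dvc is either 'cpu' or a CUDA spec such as 'cuda:0'; resolve_device is an illustrative helper, not something the script defines:

    import torch

    def resolve_device(dvc: str) -> torch.device:
        # Fall back to CPU when CUDA is requested but not available;
        # an empty string would make torch.device() raise outright.
        if dvc.startswith("cuda") and not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device(dvc)

    # e.g. opt.dvc = resolve_device(opt.dvc) before constructing PPO_agent
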
@@ -79,9 +81,8 @@ def main():
     opt.state_dim = env.observation_space.shape[0]
     opt.action_dim = env.action_space.shape[0]
     opt.max_action = float(env.action_space.high[0])
-    opt.max_steps = env._max_episode_steps
     print('Env:', EnvName[opt.EnvIdex], ' state_dim:', opt.state_dim, ' action_dim:', opt.action_dim,
-          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0], 'max_steps', opt.max_steps)
+          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0])

     # Seed Everything
     env_seed = opt.seed
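
opt.max_steps disappears here, presumably because PartitionMazeEnv is not wrapped in Gymnasium's TimeLimit and so exposes no _max_episode_steps attribute. If the script ever needs to run against a registered Gym environment again, a hedged way to read the limit without crashing looks like this (purely a sketch, not part of the commit):

    # Read the episode limit only where it exists: TimeLimit-wrapped envs
    # carry _max_episode_steps, registered envs may expose it via env.spec.
    max_steps = getattr(env, "_max_episode_steps", None)
    if max_steps is None and getattr(env, "spec", None) is not None:
        max_steps = env.spec.max_episode_steps  # can still be None
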
@@ -121,7 +122,7 @@ def main():
     traj_lenth, total_steps = 0, 0
     while total_steps < opt.Max_train_steps:
         # Do not use opt.seed directly, or it can overfit to opt.seed
-        s, info = env.reset(seed=env_seed)
+        s = env.reset(seed=env_seed)
         env_seed += 1
         done = False

@@ -137,7 +137,7 @@ def Reward_adapter(r, EnvIdex):
 def evaluate_policy(env, agent, max_action, turns):
     total_scores = 0
     for j in range(turns):
-        s, info = env.reset()
+        s = env.reset()
         done = False
         action_series = []
         while not done:
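
Both reset() call sites now unpack a single return value, which matches an environment whose reset() hands back only the observation; the standard Gymnasium API returns an (obs, info) tuple instead. A small sketch of a helper that tolerates either convention, assuming observations are never plain 2-tuples themselves; reset_obs is an illustrative name:

    def reset_obs(env, **kwargs):
        # Gymnasium envs return (obs, info); some custom envs return the
        # observation alone. Normalise both cases to the observation.
        out = env.reset(**kwargs)
        if isinstance(out, tuple) and len(out) == 2:
            return out[0]
        return out

    # e.g. s = reset_obs(env, seed=env_seed) in the training loop,
    #      s = reset_obs(env) inside evaluate_policy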