Fix PPO bug

weixin_46229132 2025-03-19 15:23:55 +08:00
parent 6dc285d3f8
commit d364a1e4df
4 changed files with 15 additions and 12 deletions

View File

@@ -1,14 +1,16 @@
+from DDPG import DDPG_agent
+from datetime import datetime
+from utils import str2bool, evaluate_policy
 import gymnasium as gym
-import os
 import shutil
 import argparse
 import torch
+# fmt: off
 import sys
+import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
-from utils import str2bool, evaluate_policy
-from datetime import datetime
-from DDPG import DDPG_agent
+# fmt: on
 
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
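The # fmt: off / # fmt: on markers bracket the sys.path hack, presumably so an import-sorting formatter cannot hoist "from env import PartitionMazeEnv" above the sys.path.append call that makes it resolvable. A minimal sketch of the pattern, assuming env.py sits one directory above the training script as the diff implies:

# fmt: off                       # keep the formatter from re-sorting the lines below
import os
import sys
# put the repo root on sys.path before importing env.py, which lives one level up
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import PartitionMazeEnv  # noqa: E402  (deliberately after the path edit)
# fmt: on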

View File

@@ -133,8 +133,8 @@ class PPO_agent(object):
         self.dw_hoder[idx] = dw
 
     def save(self,EnvName, timestep):
-        torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
-        torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
+        torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
+        torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))
 
     def load(self,EnvName, timestep):
         self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
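Note the asymmetry this hunk leaves behind: save() now writes under ./weights/ while the unchanged load() in the context below still reads from ./model/, so a fresh save followed by a load would miss the new files unless old checkpoints remain in ./model/. If that is unintended, a single directory constant removes the drift. A sketch under that assumption (CKPT_DIR is a name invented here, and the makedirs guard is an addition, since torch.save does not create missing directories):

import os
import torch

CKPT_DIR = "./weights"  # hypothetical single source of truth for checkpoint paths

# sketch of PPO_agent.save / PPO_agent.load sharing one directory
def save(self, EnvName, timestep):
    os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save does not create the directory itself
    torch.save(self.actor.state_dict(), "{}/{}_actor{}.pth".format(CKPT_DIR, EnvName, timestep))
    torch.save(self.critic.state_dict(), "{}/{}_q_critic{}.pth".format(CKPT_DIR, EnvName, timestep))

def load(self, EnvName, timestep):
    self.actor.load_state_dict(torch.load(
        "{}/{}_actor{}.pth".format(CKPT_DIR, EnvName, timestep), map_location=self.dvc))
    self.critic.load_state_dict(torch.load(
        "{}/{}_q_critic{}.pth".format(CKPT_DIR, EnvName, timestep), map_location=self.dvc))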

View File

@@ -1,18 +1,20 @@
 from datetime import datetime
-import os
 import shutil
 import argparse
 import torch
 import gymnasium as gym
 from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy
 from PPO import PPO_agent
+# fmt: off
 import sys
+import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
+# fmt: on
 
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
-parser.add_argument('--dvc', type=str, default='',
+parser.add_argument('--dvc', type=str, default='cpu',
                     help='running device: cuda or cpu')
 parser.add_argument('--EnvIdex', type=int, default=0,
                     help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
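The --dvc default moving from '' to 'cpu' matters because an empty string is not a valid device name for torch.device, so code that builds a device from the flag would fail unless the user always passed one explicitly. A short sketch of how such a flag is typically consumed downstream (variable names here are illustrative):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cpu',
                    help='running device: cuda or cpu')
opt = parser.parse_args([])           # parse defaults only, for illustration
device = torch.device(opt.dvc)        # 'cpu' always works; 'cuda' needs a GPU build
x = torch.zeros(3, device=device)     # tensor created directly on the chosen device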
@@ -79,9 +81,8 @@ def main():
     opt.state_dim = env.observation_space.shape[0]
     opt.action_dim = env.action_space.shape[0]
     opt.max_action = float(env.action_space.high[0])
-    opt.max_steps = env._max_episode_steps
     print('Env:', EnvName[opt.EnvIdex], ' state_dim:', opt.state_dim, ' action_dim:', opt.action_dim,
-          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0], 'max_steps', opt.max_steps)
+          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0])
 
     # Seed Everything
     env_seed = opt.seed
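The dropped opt.max_steps line reads _max_episode_steps, which Gymnasium's TimeLimit wrapper sets on registered benchmark envs but which presumably does not exist on the bare PartitionMazeEnv, so keeping it would raise AttributeError there. If the horizon is still useful when available, a guarded lookup avoids the crash; a sketch (the fallback value is invented here):

# use the TimeLimit horizon when the wrapper provides one, else a fallback
max_steps = getattr(env, '_max_episode_steps', None)
if max_steps is None:
    max_steps = 1000  # hypothetical cap for the custom PartitionMazeEnv
print('max_steps', max_steps)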
@@ -121,7 +122,7 @@ def main():
     traj_lenth, total_steps = 0, 0
     while total_steps < opt.Max_train_steps:
         # Do not use opt.seed directly, or it can overfit to opt.seed
-        s, info = env.reset(seed=env_seed)
+        s = env.reset(seed=env_seed)
         env_seed += 1
         done = False

View File

@@ -137,7 +137,7 @@ def Reward_adapter(r, EnvIdex):
 def evaluate_policy(env, agent, max_action, turns):
     total_scores = 0
     for j in range(turns):
-        s, info = env.reset()
+        s = env.reset()
         done = False
         action_series = []
         while not done:
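Both this hunk and the training loop above now unpack reset() as a single observation, which matches an env written against the classic Gym API; Gymnasium's reset() returns an (obs, info) pair, which is what the old "s, info = ..." unpacking expected. If one helper must serve both styles, a small shim keeps the call sites uniform; a sketch (reset_obs is a helper name invented here):

def reset_obs(env, **kwargs):
    # accept both the classic Gym API (returns obs) and the
    # Gymnasium API (returns an (obs, info) tuple)
    out = env.reset(**kwargs)
    if isinstance(out, tuple) and len(out) == 2:
        return out[0]
    return out

# usage at the two fixed call sites:
#   s = reset_obs(env, seed=env_seed)   # training loop
#   s = reset_obs(env)                  # evaluate_policy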