Fix PPO bug

Author: weixin_46229132
Date: 2025-03-19 15:23:55 +08:00
Parent: 6dc285d3f8
Commit: d364a1e4df
4 changed files with 15 additions and 12 deletions

View File

@@ -1,14 +1,16 @@
-from DDPG import DDPG_agent
-from datetime import datetime
-from utils import str2bool, evaluate_policy
 import gymnasium as gym
 import os
 import shutil
 import argparse
 import torch
+# fmt: off
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
+from utils import str2bool, evaluate_policy
+from datetime import datetime
+from DDPG import DDPG_agent
+# fmt: on
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
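Note: the import block above appends the repository root (the parent of this script's directory) to sys.path so that `from env import PartitionMazeEnv` resolves, and the `# fmt: off` / `# fmt: on` markers keep auto-formatters that honor them (e.g. Black) from hoisting the project imports above that path tweak, which would break them at import time. A minimal sketch of the same pattern, assuming `env.py` sits one directory above the training script:

```python
# fmt: off
import os
import sys

# Make the repository root importable before any project-level import.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from env import PartitionMazeEnv  # resolves only after the path tweak above
# fmt: on
```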

View File

@@ -133,8 +133,8 @@ class PPO_agent(object):
         self.dw_hoder[idx] = dw

     def save(self,EnvName, timestep):
-        torch.save(self.actor.state_dict(), "./model/{}_actor{}.pth".format(EnvName,timestep))
-        torch.save(self.critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep))
+        torch.save(self.actor.state_dict(), "./weights/{}_actor{}.pth".format(EnvName,timestep))
+        torch.save(self.critic.state_dict(), "./weights/{}_q_critic{}.pth".format(EnvName,timestep))

     def load(self,EnvName, timestep):
         self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc))
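Note: this hunk redirects checkpoint saving from ./model to ./weights, while the load path visible in the surrounding context still points at ./model, so the load side may need a matching update. Also, torch.save does not create missing directories, so the target folder has to exist before the first save. A minimal sketch of a save helper that guards against that; the helper name and default folder are illustrative, not part of this repository:

```python
import os
import torch

def save_checkpoint(net, env_name, timestep, ckpt_dir="./weights"):
    # torch.save raises FileNotFoundError if the directory is missing, so create it first.
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, "{}_actor{}.pth".format(env_name, timestep))
    torch.save(net.state_dict(), path)
    return path
```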

View File

@@ -1,18 +1,20 @@
 from datetime import datetime
 import os
 import shutil
 import argparse
 import torch
 import gymnasium as gym
 from utils import str2bool, Action_adapter, Reward_adapter, evaluate_policy
 from PPO import PPO_agent
+# fmt: off
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from env import PartitionMazeEnv
+# fmt: on
 '''Hyperparameter Setting'''
 parser = argparse.ArgumentParser()
-parser.add_argument('--dvc', type=str, default='',
+parser.add_argument('--dvc', type=str, default='cpu',
                     help='running device: cuda or cpu')
 parser.add_argument('--EnvIdex', type=int, default=0,
                     help='PV1, Lch_Cv2, Humanv4, HCv4, BWv3, BWHv3')
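Note: the --dvc default changes from an empty string to 'cpu', so the script runs without a device argument on machines that lack CUDA. A short sketch of a defensive device selection on top of that default; the fallback guard is an illustration, not code from this repository:

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cpu',
                    help='running device: cuda or cpu')
opt = parser.parse_args([])  # empty list -> use defaults, for illustration only

# Fall back to CPU when CUDA is requested but not available.
if opt.dvc == 'cuda' and not torch.cuda.is_available():
    opt.dvc = 'cpu'
device = torch.device(opt.dvc)
```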
@@ -79,9 +81,8 @@ def main():
     opt.state_dim = env.observation_space.shape[0]
     opt.action_dim = env.action_space.shape[0]
     opt.max_action = float(env.action_space.high[0])
-    opt.max_steps = env._max_episode_steps
     print('Env:', EnvName[opt.EnvIdex], ' state_dim:', opt.state_dim, ' action_dim:', opt.action_dim,
-          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0], 'max_steps', opt.max_steps)
+          ' max_a:', opt.max_action, ' min_a:', env.action_space.low[0])

     # Seed Everything
     env_seed = opt.seed
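Note: `_max_episode_steps` is supplied by gymnasium's TimeLimit wrapper rather than by every environment; the custom PartitionMazeEnv evidently does not expose it, so the line and the matching field in the print statement are dropped. If an episode cap is still wanted, the environment can be wrapped explicitly. A sketch under the assumptions that PartitionMazeEnv is a standard gymnasium.Env constructible without arguments and that a 500-step cap is acceptable (both assumptions, not facts from this repository):

```python
from gymnasium.wrappers import TimeLimit
from env import PartitionMazeEnv  # assumes the repository-root path tweak shown earlier

MAX_EPISODE_STEPS = 500  # illustrative cap
env = TimeLimit(PartitionMazeEnv(), max_episode_steps=MAX_EPISODE_STEPS)
# Track the cap directly instead of reading the wrapper's private attribute.
opt_max_steps = MAX_EPISODE_STEPS
```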
@@ -121,7 +122,7 @@ def main():
         traj_lenth, total_steps = 0, 0
         while total_steps < opt.Max_train_steps:
             # Do not use opt.seed directly, or it can overfit to opt.seed
-            s, info = env.reset(seed=env_seed)
+            s = env.reset(seed=env_seed)
             env_seed += 1
             done = False
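Note: gymnasium's Env.reset returns an (observation, info) pair, while PartitionMazeEnv here evidently returns only the observation, which is what this fix (and the matching one in evaluate_policy below) accounts for. A small compatibility helper that tolerates both conventions, purely illustrative and not part of the repository:

```python
def reset_obs(env, **kwargs):
    """Return just the observation, whether reset() yields obs or (obs, info)."""
    out = env.reset(**kwargs)
    if isinstance(out, tuple) and len(out) == 2:
        obs, _info = out  # gymnasium-style (obs, info)
        return obs
    return out  # custom envs that return the observation alone

# usage sketch: s = reset_obs(env, seed=env_seed)
```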

View File

@@ -137,7 +137,7 @@ def Reward_adapter(r, EnvIdex):
 def evaluate_policy(env, agent, max_action, turns):
     total_scores = 0
     for j in range(turns):
-        s, info = env.reset()
+        s = env.reset()
         done = False
         action_series = []
         while not done: