修PPObug

This commit is contained in:
weixin_46229132 2025-03-12 16:09:19 +08:00
parent fe4e754cc4
commit d53eda2570
3 changed files with 14 additions and 33 deletions

View File

@ -1,27 +0,0 @@
"""
This file contains the arguments to parse at command line.
File main.py will call get_args, which then the arguments
will be returned.
"""
import argparse
def get_args():
"""
Description:
Parses arguments at command line.
Parameters:
None
Return:
args - the arguments parsed
"""
parser = argparse.ArgumentParser()
parser.add_argument('--mode', dest='mode', type=str, default='train') # can be 'train' or 'test'
parser.add_argument('--actor_model', dest='actor_model', type=str, default='') # your actor model filename
parser.add_argument('--critic_model', dest='critic_model', type=str, default='') # your critic model filename
args = parser.parse_args()
return args

View File

@ -139,7 +139,7 @@ class PartitionMazeEnv(gym.Env):
bs_time = self.bs_time_factor * (1 - rho) * d bs_time = self.bs_time_factor * (1 - rho) * d
self.rectangles[(i, j)] = { self.rectangles[(i, j)] = {
'center': ((h_boundaries[i] + h_boundaries[i+1]) * self.H / 2, (v_boundaries[j+1] - v_boundaries[j]) * self.W / 2), 'center': ((h_boundaries[i] + h_boundaries[i+1]) * self.H / 2, (v_boundaries[j+1] + v_boundaries[j]) * self.W / 2),
'flight_time': flight_time, 'flight_time': flight_time,
'bs_time': bs_time, 'bs_time': bs_time,
'is_visited': False 'is_visited': False
@ -244,9 +244,11 @@ class PartitionMazeEnv(gym.Env):
# 区域覆盖完毕,根据轨迹计算各车队的执行时间 # 区域覆盖完毕,根据轨迹计算各车队的执行时间
T = max([self._compute_motorcade_time(idx) T = max([self._compute_motorcade_time(idx)
for idx in range(self.num_cars)]) for idx in range(self.num_cars)])
# print(T)
# print(self.car_traj)
reward += -(T - self.BASE_LINE)
print(T) print(T)
print(self.car_traj) print(self.car_traj)
reward += -(T - self.BASE_LINE)
elif done and self.step_count >= self.MAX_STEPS: elif done and self.step_count >= self.MAX_STEPS:
reward += -100 reward += -100
@ -267,8 +269,8 @@ class PartitionMazeEnv(gym.Env):
second_point = self.car_traj[idx][i + 1] second_point = self.car_traj[idx][i + 1]
car_time += math.dist(self.rectangles[tuple(first_point)]['center'], self.rectangles[tuple(second_point)]['center']) * \ car_time += math.dist(self.rectangles[tuple(first_point)]['center'], self.rectangles[tuple(second_point)]['center']) * \
self.car_time_factor self.car_time_factor
car_time + math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.H, self.W]) car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.H / 2, self.W / 2])
car_time + math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.H, self.W]) car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.H / 2, self.W / 2])
return max(float(car_time) + flight_time, bs_time) return max(float(car_time) + flight_time, bs_time)

View File

@ -6,8 +6,8 @@
import gymnasium as gym import gymnasium as gym
import sys import sys
import torch import torch
import argparse
from arguments import get_args
from ppo import PPO from ppo import PPO
from network import FeedForwardNN from network import FeedForwardNN
from eval_policy import eval_policy from eval_policy import eval_policy
@ -119,5 +119,11 @@ def main(args):
test(env=env, actor_model=args.actor_model) test(env=env, actor_model=args.actor_model)
if __name__ == '__main__': if __name__ == '__main__':
args = get_args() # Parse arguments from command line parser = argparse.ArgumentParser()
parser.add_argument('--mode', dest='mode', type=str, default='train') # can be 'train' or 'test'
parser.add_argument('--actor_model', dest='actor_model', type=str, default='') # your actor model filename
parser.add_argument('--critic_model', dest='critic_model', type=str, default='') # your critic model filename
args = parser.parse_args()
main(args) main(args)