修PPObug
This commit is contained in:
parent
fe4e754cc4
commit
d53eda2570
@ -1,27 +0,0 @@
|
||||
"""
|
||||
This file contains the arguments to parse at command line.
|
||||
File main.py will call get_args, which then the arguments
|
||||
will be returned.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
def get_args():
|
||||
"""
|
||||
Description:
|
||||
Parses arguments at command line.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Return:
|
||||
args - the arguments parsed
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--mode', dest='mode', type=str, default='train') # can be 'train' or 'test'
|
||||
parser.add_argument('--actor_model', dest='actor_model', type=str, default='') # your actor model filename
|
||||
parser.add_argument('--critic_model', dest='critic_model', type=str, default='') # your critic model filename
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
10
PPO/env.py
10
PPO/env.py
@ -139,7 +139,7 @@ class PartitionMazeEnv(gym.Env):
|
||||
bs_time = self.bs_time_factor * (1 - rho) * d
|
||||
|
||||
self.rectangles[(i, j)] = {
|
||||
'center': ((h_boundaries[i] + h_boundaries[i+1]) * self.H / 2, (v_boundaries[j+1] - v_boundaries[j]) * self.W / 2),
|
||||
'center': ((h_boundaries[i] + h_boundaries[i+1]) * self.H / 2, (v_boundaries[j+1] + v_boundaries[j]) * self.W / 2),
|
||||
'flight_time': flight_time,
|
||||
'bs_time': bs_time,
|
||||
'is_visited': False
|
||||
@ -244,9 +244,11 @@ class PartitionMazeEnv(gym.Env):
|
||||
# 区域覆盖完毕,根据轨迹计算各车队的执行时间
|
||||
T = max([self._compute_motorcade_time(idx)
|
||||
for idx in range(self.num_cars)])
|
||||
# print(T)
|
||||
# print(self.car_traj)
|
||||
reward += -(T - self.BASE_LINE)
|
||||
print(T)
|
||||
print(self.car_traj)
|
||||
reward += -(T - self.BASE_LINE)
|
||||
elif done and self.step_count >= self.MAX_STEPS:
|
||||
reward += -100
|
||||
|
||||
@ -267,8 +269,8 @@ class PartitionMazeEnv(gym.Env):
|
||||
second_point = self.car_traj[idx][i + 1]
|
||||
car_time += math.dist(self.rectangles[tuple(first_point)]['center'], self.rectangles[tuple(second_point)]['center']) * \
|
||||
self.car_time_factor
|
||||
car_time + math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.H, self.W])
|
||||
car_time + math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.H, self.W])
|
||||
car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.H / 2, self.W / 2])
|
||||
car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.H / 2, self.W / 2])
|
||||
|
||||
return max(float(car_time) + flight_time, bs_time)
|
||||
|
||||
|
10
PPO/main.py
10
PPO/main.py
@ -6,8 +6,8 @@
|
||||
import gymnasium as gym
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
from arguments import get_args
|
||||
from ppo import PPO
|
||||
from network import FeedForwardNN
|
||||
from eval_policy import eval_policy
|
||||
@ -119,5 +119,11 @@ def main(args):
|
||||
test(env=env, actor_model=args.actor_model)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args() # Parse arguments from command line
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--mode', dest='mode', type=str, default='train') # can be 'train' or 'test'
|
||||
parser.add_argument('--actor_model', dest='actor_model', type=str, default='') # your actor model filename
|
||||
parser.add_argument('--critic_model', dest='critic_model', type=str, default='') # your critic model filename
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
Loading…
Reference in New Issue
Block a user