修PPObug
This commit is contained in:
parent
fe4e754cc4
commit
d53eda2570
@ -1,27 +0,0 @@
|
|||||||
"""
|
|
||||||
This file contains the arguments to parse at command line.
|
|
||||||
File main.py will call get_args, and the parsed arguments
|
|
||||||
will be returned.
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
"""
|
|
||||||
Description:
|
|
||||||
Parses arguments at command line.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
None
|
|
||||||
|
|
||||||
Return:
|
|
||||||
args - the arguments parsed
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
parser.add_argument('--mode', dest='mode', type=str, default='train') # can be 'train' or 'test'
|
|
||||||
parser.add_argument('--actor_model', dest='actor_model', type=str, default='') # your actor model filename
|
|
||||||
parser.add_argument('--critic_model', dest='critic_model', type=str, default='') # your critic model filename
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return args
|
|
10
PPO/env.py
10
PPO/env.py
@ -139,7 +139,7 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
bs_time = self.bs_time_factor * (1 - rho) * d
|
bs_time = self.bs_time_factor * (1 - rho) * d
|
||||||
|
|
||||||
self.rectangles[(i, j)] = {
|
self.rectangles[(i, j)] = {
|
||||||
'center': ((h_boundaries[i] + h_boundaries[i+1]) * self.H / 2, (v_boundaries[j+1] - v_boundaries[j]) * self.W / 2),
|
'center': ((h_boundaries[i] + h_boundaries[i+1]) * self.H / 2, (v_boundaries[j+1] + v_boundaries[j]) * self.W / 2),
|
||||||
'flight_time': flight_time,
|
'flight_time': flight_time,
|
||||||
'bs_time': bs_time,
|
'bs_time': bs_time,
|
||||||
'is_visited': False
|
'is_visited': False
|
||||||
@ -244,9 +244,11 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
# 区域覆盖完毕,根据轨迹计算各车队的执行时间
|
# 区域覆盖完毕,根据轨迹计算各车队的执行时间
|
||||||
T = max([self._compute_motorcade_time(idx)
|
T = max([self._compute_motorcade_time(idx)
|
||||||
for idx in range(self.num_cars)])
|
for idx in range(self.num_cars)])
|
||||||
|
# print(T)
|
||||||
|
# print(self.car_traj)
|
||||||
|
reward += -(T - self.BASE_LINE)
|
||||||
print(T)
|
print(T)
|
||||||
print(self.car_traj)
|
print(self.car_traj)
|
||||||
reward += -(T - self.BASE_LINE)
|
|
||||||
elif done and self.step_count >= self.MAX_STEPS:
|
elif done and self.step_count >= self.MAX_STEPS:
|
||||||
reward += -100
|
reward += -100
|
||||||
|
|
||||||
@ -267,8 +269,8 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
second_point = self.car_traj[idx][i + 1]
|
second_point = self.car_traj[idx][i + 1]
|
||||||
car_time += math.dist(self.rectangles[tuple(first_point)]['center'], self.rectangles[tuple(second_point)]['center']) * \
|
car_time += math.dist(self.rectangles[tuple(first_point)]['center'], self.rectangles[tuple(second_point)]['center']) * \
|
||||||
self.car_time_factor
|
self.car_time_factor
|
||||||
car_time + math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.H, self.W])
|
car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.H / 2, self.W / 2])
|
||||||
car_time + math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.H, self.W])
|
car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.H / 2, self.W / 2])
|
||||||
|
|
||||||
return max(float(car_time) + flight_time, bs_time)
|
return max(float(car_time) + flight_time, bs_time)
|
||||||
|
|
||||||
|
10
PPO/main.py
10
PPO/main.py
@ -6,8 +6,8 @@
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
import argparse
|
||||||
|
|
||||||
from arguments import get_args
|
|
||||||
from ppo import PPO
|
from ppo import PPO
|
||||||
from network import FeedForwardNN
|
from network import FeedForwardNN
|
||||||
from eval_policy import eval_policy
|
from eval_policy import eval_policy
|
||||||
@ -119,5 +119,11 @@ def main(args):
|
|||||||
test(env=env, actor_model=args.actor_model)
|
test(env=env, actor_model=args.actor_model)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = get_args() # Parse arguments from command line
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument('--mode', dest='mode', type=str, default='train') # can be 'train' or 'test'
|
||||||
|
parser.add_argument('--actor_model', dest='actor_model', type=str, default='') # your actor model filename
|
||||||
|
parser.add_argument('--critic_model', dest='critic_model', type=str, default='') # your critic model filename
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
Loading…
Reference in New Issue
Block a user