验证阶段加输出,更新奖励
This commit is contained in:
parent
c96c36d4cd
commit
e35dd10326
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,8 +9,6 @@ __pycache__/
|
|||||||
|
|
||||||
# Pytorch weights
|
# Pytorch weights
|
||||||
weights/
|
weights/
|
||||||
PPO_preTrained/
|
|
||||||
PPO_logs/
|
|
||||||
logs/
|
logs/
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
|
@@ -125,7 +125,7 @@ def main():
|
|||||||
|
|
||||||
'''record & log'''
|
'''record & log'''
|
||||||
if total_steps % opt.eval_interval == 0:
|
if total_steps % opt.eval_interval == 0:
|
||||||
ep_r = evaluate_policy(eval_env, agent, turns=3)
|
ep_r = evaluate_policy(eval_env, agent, turns=1)
|
||||||
if opt.write:
|
if opt.write:
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
'ep_r', ep_r, global_step=total_steps)
|
'ep_r', ep_r, global_step=total_steps)
|
||||||
|
34
env.py
34
env.py
@@ -41,7 +41,7 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
##############################
|
##############################
|
||||||
self.CUT_NUM = 4 # 横切一半,竖切一半
|
self.CUT_NUM = 4 # 横切一半,竖切一半
|
||||||
self.BASE_LINE = 4000 # 基准时间,通过greedy或者蒙特卡洛计算出来
|
self.BASE_LINE = 4000 # 基准时间,通过greedy或者蒙特卡洛计算出来
|
||||||
self.MAX_STEPS = 200 # 迷宫走法步数上限
|
self.MAX_STEPS = 50 # 迷宫走法步数上限
|
||||||
|
|
||||||
self.phase = 0 # 阶段控制,0:区域划分阶段,1:迷宫初始化阶段,2:走迷宫阶段
|
self.phase = 0 # 阶段控制,0:区域划分阶段,1:迷宫初始化阶段,2:走迷宫阶段
|
||||||
self.partition_step = 0 # 区域划分阶段步数,范围 0~4
|
self.partition_step = 0 # 区域划分阶段步数,范围 0~4
|
||||||
@@ -159,7 +159,7 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
else:
|
else:
|
||||||
# 进入阶段 1:初始化迷宫
|
# 进入阶段 1:初始化迷宫
|
||||||
self.phase = 1
|
self.phase = 1
|
||||||
reward = 10
|
reward = 100
|
||||||
|
|
||||||
# 构建反向索引,方便后续计算
|
# 构建反向索引,方便后续计算
|
||||||
self.reverse_rectangles = {v['center']: k for k, v in self.rectangles.items()}
|
self.reverse_rectangles = {v['center']: k for k, v in self.rectangles.items()}
|
||||||
@@ -194,6 +194,8 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
# 查表,找出当前车辆所在的网格
|
# 查表,找出当前车辆所在的网格
|
||||||
current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
|
current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
|
||||||
# 当前动作 a 为 1 维连续动作,映射到四个方向
|
# 当前动作 a 为 1 维连续动作,映射到四个方向
|
||||||
if a < 0.2:
|
if a < 0.2:
|
||||||
move_dir = 'up'
|
move_dir = 'up'
|
||||||
@@ -209,16 +211,31 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
# 初始化新的行、列为当前值
|
# 初始化新的行、列为当前值
|
||||||
new_row, new_col = current_row, current_col
|
new_row, new_col = current_row, current_col
|
||||||
|
|
||||||
if move_dir == 'up' and current_row > 0:
|
if move_dir == 'up':
|
||||||
|
if current_row > 0:
|
||||||
new_row = current_row - 1
|
new_row = current_row - 1
|
||||||
elif move_dir == 'down' and current_row < len(self.row_cuts) - 2:
|
else: # 错误的移动给一些惩罚?
|
||||||
|
new_row = current_row
|
||||||
|
# reward -= 10
|
||||||
|
elif move_dir == 'down':
|
||||||
|
if current_row < len(self.row_cuts) - 2:
|
||||||
new_row = current_row + 1
|
new_row = current_row + 1
|
||||||
elif move_dir == 'left' and current_col > 0:
|
else:
|
||||||
|
new_row = current_row
|
||||||
|
# reward -= 10
|
||||||
|
elif move_dir == 'left':
|
||||||
|
if current_col > 0:
|
||||||
new_col = current_col - 1
|
new_col = current_col - 1
|
||||||
elif move_dir == 'right' and current_col < len(self.col_cuts) - 2:
|
else:
|
||||||
|
new_col = current_col
|
||||||
|
# reward -= 10
|
||||||
|
elif move_dir == 'right':
|
||||||
|
if current_col < len(self.col_cuts) - 2:
|
||||||
new_col = current_col + 1
|
new_col = current_col + 1
|
||||||
|
else:
|
||||||
|
new_col = current_col
|
||||||
|
# reward -= 10
|
||||||
# 如果移动不合法,或者动作为stay,则保持原位置
|
# 如果移动不合法,或者动作为stay,则保持原位置
|
||||||
# TODO 移动不合法,加一些惩罚
|
|
||||||
|
|
||||||
# 更新车辆位置
|
# 更新车辆位置
|
||||||
self.car_pos[current_car] = self.rectangles[(
|
self.car_pos[current_car] = self.rectangles[(
|
||||||
@@ -235,7 +252,6 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
# 观察状态
|
# 观察状态
|
||||||
state = np.concatenate(
|
state = np.concatenate(
|
||||||
[self.partition_values, np.array(self.car_pos).flatten()])
|
[self.partition_values, np.array(self.car_pos).flatten()])
|
||||||
reward = 0
|
|
||||||
|
|
||||||
# Episode 终止条件:所有网格均被访问或步数达到上限
|
# Episode 终止条件:所有网格均被访问或步数达到上限
|
||||||
done = all([value['is_visited'] for _, value in self.rectangles.items()]) or (
|
done = all([value['is_visited'] for _, value in self.rectangles.items()]) or (
|
||||||
@@ -247,7 +263,7 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
# print(T)
|
# print(T)
|
||||||
# print(self.partition_values)
|
# print(self.partition_values)
|
||||||
# print(self.car_traj)
|
# print(self.car_traj)
|
||||||
reward += self.BASE_LINE / T * 100
|
reward += self.BASE_LINE / T * 1000
|
||||||
elif done and self.step_count >= self.MAX_STEPS:
|
elif done and self.step_count >= self.MAX_STEPS:
|
||||||
reward += -1000
|
reward += -1000
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user