验证阶段加输出,更新奖励
This commit is contained in:
parent
c96c36d4cd
commit
e35dd10326
2
.gitignore
vendored
2
.gitignore
vendored
@ -9,8 +9,6 @@ __pycache__/
|
||||
|
||||
# Pytorch weights
|
||||
weights/
|
||||
PPO_preTrained/
|
||||
PPO_logs/
|
||||
logs/
|
||||
|
||||
# Distribution / packaging
|
||||
|
@ -125,7 +125,7 @@ def main():
|
||||
|
||||
'''record & log'''
|
||||
if total_steps % opt.eval_interval == 0:
|
||||
ep_r = evaluate_policy(eval_env, agent, turns=3)
|
||||
ep_r = evaluate_policy(eval_env, agent, turns=1)
|
||||
if opt.write:
|
||||
writer.add_scalar(
|
||||
'ep_r', ep_r, global_step=total_steps)
|
||||
|
42
env.py
42
env.py
@ -41,7 +41,7 @@ class PartitionMazeEnv(gym.Env):
|
||||
##############################
|
||||
self.CUT_NUM = 4 # 横切一半,竖切一半
|
||||
self.BASE_LINE = 4000 # 基准时间,通过greedy或者蒙特卡洛计算出来
|
||||
self.MAX_STEPS = 200 # 迷宫走法步数上限
|
||||
self.MAX_STEPS = 50 # 迷宫走法步数上限
|
||||
|
||||
self.phase = 0 # 阶段控制,0:区域划分阶段,1:迷宫初始化阶段,2:走迷宫阶段
|
||||
self.partition_step = 0 # 区域划分阶段步数,范围 0~4
|
||||
@ -159,7 +159,7 @@ class PartitionMazeEnv(gym.Env):
|
||||
else:
|
||||
# 进入阶段 1:初始化迷宫
|
||||
self.phase = 1
|
||||
reward = 10
|
||||
reward = 100
|
||||
|
||||
# 构建反向索引,方便后续计算
|
||||
self.reverse_rectangles = {v['center']: k for k, v in self.rectangles.items()}
|
||||
@ -194,6 +194,8 @@ class PartitionMazeEnv(gym.Env):
|
||||
# 查表,找出当前车辆所在的网格
|
||||
current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
|
||||
|
||||
reward = 0
|
||||
|
||||
# 当前动作 a 为 1 维连续动作,映射到四个方向
|
||||
if a < 0.2:
|
||||
move_dir = 'up'
|
||||
@ -209,16 +211,31 @@ class PartitionMazeEnv(gym.Env):
|
||||
# 初始化新的行、列为当前值
|
||||
new_row, new_col = current_row, current_col
|
||||
|
||||
if move_dir == 'up' and current_row > 0:
|
||||
new_row = current_row - 1
|
||||
elif move_dir == 'down' and current_row < len(self.row_cuts) - 2:
|
||||
new_row = current_row + 1
|
||||
elif move_dir == 'left' and current_col > 0:
|
||||
new_col = current_col - 1
|
||||
elif move_dir == 'right' and current_col < len(self.col_cuts) - 2:
|
||||
new_col = current_col + 1
|
||||
if move_dir == 'up':
|
||||
if current_row > 0:
|
||||
new_row = current_row - 1
|
||||
else: # 错误的移动给一些惩罚?
|
||||
new_row = current_row
|
||||
# reward -= 10
|
||||
elif move_dir == 'down':
|
||||
if current_row < len(self.row_cuts) - 2:
|
||||
new_row = current_row + 1
|
||||
else:
|
||||
new_row = current_row
|
||||
# reward -= 10
|
||||
elif move_dir == 'left':
|
||||
if current_col > 0:
|
||||
new_col = current_col - 1
|
||||
else:
|
||||
new_col = current_col
|
||||
# reward -= 10
|
||||
elif move_dir == 'right':
|
||||
if current_col < len(self.col_cuts) - 2:
|
||||
new_col = current_col + 1
|
||||
else:
|
||||
new_col = current_col
|
||||
# reward -= 10
|
||||
# 如果移动不合法,或者动作为stay,则保持原位置
|
||||
# TODO 移动不合法,加一些惩罚
|
||||
|
||||
# 更新车辆位置
|
||||
self.car_pos[current_car] = self.rectangles[(
|
||||
@ -235,7 +252,6 @@ class PartitionMazeEnv(gym.Env):
|
||||
# 观察状态
|
||||
state = np.concatenate(
|
||||
[self.partition_values, np.array(self.car_pos).flatten()])
|
||||
reward = 0
|
||||
|
||||
# Episode 终止条件:所有网格均被访问或步数达到上限
|
||||
done = all([value['is_visited'] for _, value in self.rectangles.items()]) or (
|
||||
@ -247,7 +263,7 @@ class PartitionMazeEnv(gym.Env):
|
||||
# print(T)
|
||||
# print(self.partition_values)
|
||||
# print(self.car_traj)
|
||||
reward += self.BASE_LINE / T * 100
|
||||
reward += self.BASE_LINE / T * 1000
|
||||
elif done and self.step_count >= self.MAX_STEPS:
|
||||
reward += -1000
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user