修改离散环境:若一轮中所有车辆均未移动,则给予惩罚(奖励扣 10 分)
This commit is contained in:
parent
4972306ca7
commit
3dba6e4a53
@ -14,7 +14,7 @@ def evaluate_policy(env, agent, turns = 3):
|
|||||||
action_series.append(a)
|
action_series.append(a)
|
||||||
total_scores += r
|
total_scores += r
|
||||||
s = s_next
|
s = s_next
|
||||||
print('action series: ', np.roudn(action_series, 3))
|
print('action series: ', np.round(action_series, 3))
|
||||||
print('state: ', s)
|
print('state: ', s)
|
||||||
return int(total_scores/turns)
|
return int(total_scores/turns)
|
||||||
|
|
||||||
|
26
env_dis.py
26
env_dis.py
@ -209,10 +209,11 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
|
|
||||||
elif self.phase == 2:
|
elif self.phase == 2:
|
||||||
# 阶段 2:路径规划(走迷宫)
|
# 阶段 2:路径规划(走迷宫)
|
||||||
|
reward = 0
|
||||||
|
|
||||||
# 后 4 个动作对应上下左右移动
|
# 后 4 个动作对应上下左右移动
|
||||||
current_car = self.current_car_index
|
current_car = self.current_car_index
|
||||||
current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
|
current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
|
||||||
|
|
||||||
# 初始化新的行、列为当前值
|
# 初始化新的行、列为当前值
|
||||||
new_row, new_col = current_row, current_col
|
new_row, new_col = current_row, current_col
|
||||||
|
|
||||||
@ -227,21 +228,28 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
else:
|
else:
|
||||||
# 无效动作,保持原地
|
# 无效动作,保持原地
|
||||||
pass
|
pass
|
||||||
|
# 检查是否移动
|
||||||
|
car_moved = (new_row != current_row or new_col != current_col)
|
||||||
# 更新车辆位置
|
# 更新车辆位置
|
||||||
self.car_pos[current_car] = self.rectangles[(
|
self.car_pos[current_car] = self.rectangles[(
|
||||||
new_row, new_col)]['center']
|
new_row, new_col)]['center']
|
||||||
if new_row != current_row or new_col != current_col:
|
if car_moved:
|
||||||
self.car_traj[current_car].append((new_row, new_col))
|
self.car_traj[current_car].append((new_row, new_col))
|
||||||
self.step_count += 1
|
|
||||||
self.current_car_index = (
|
|
||||||
self.current_car_index + 1) % self.num_cars
|
|
||||||
|
|
||||||
# 更新访问标记:将新网格标记为已访问
|
# 更新访问标记:将新网格标记为已访问
|
||||||
self.rectangles[(new_row, new_col)]['is_visited'] = True
|
self.rectangles[(new_row, new_col)]['is_visited'] = True
|
||||||
|
|
||||||
# 观察状态
|
# 记录所有车辆一轮中是否移动
|
||||||
reward = 0
|
if self.current_car_index == 0:
|
||||||
|
# 如果是新一轮的开始,初始化移动标记
|
||||||
|
self.cars_moved = [False] * self.num_cars
|
||||||
|
self.cars_moved[current_car] = car_moved
|
||||||
|
# 如果一轮结束,检查是否所有车辆都没有移动
|
||||||
|
if self.current_car_index == (self.num_cars - 1) and not any(self.cars_moved):
|
||||||
|
reward -= 10 # 扣除 10 分奖励
|
||||||
|
|
||||||
|
self.step_count += 1
|
||||||
|
self.current_car_index = (
|
||||||
|
self.current_car_index + 1) % self.num_cars
|
||||||
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||||
visit_status = np.zeros(max_regions, dtype=np.float32)
|
visit_status = np.zeros(max_regions, dtype=np.float32)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user