修改离散环境:车辆连续不动时给予惩罚
This commit is contained in:
parent
4972306ca7
commit
3dba6e4a53
@ -14,7 +14,7 @@ def evaluate_policy(env, agent, turns = 3):
|
||||
action_series.append(a)
|
||||
total_scores += r
|
||||
s = s_next
|
||||
print('action series: ', np.roudn(action_series, 3))
|
||||
print('action series: ', np.round(action_series, 3))
|
||||
print('state: ', s)
|
||||
return int(total_scores/turns)
|
||||
|
||||
|
26
env_dis.py
26
env_dis.py
@ -209,10 +209,11 @@ class PartitionMazeEnv(gym.Env):
|
||||
|
||||
elif self.phase == 2:
|
||||
# 阶段 2:路径规划(走迷宫)
|
||||
reward = 0
|
||||
|
||||
# 后 4 个动作对应上下左右移动
|
||||
current_car = self.current_car_index
|
||||
current_row, current_col = self.reverse_rectangles[self.car_pos[current_car]]
|
||||
|
||||
# 初始化新的行、列为当前值
|
||||
new_row, new_col = current_row, current_col
|
||||
|
||||
@ -227,21 +228,28 @@ class PartitionMazeEnv(gym.Env):
|
||||
else:
|
||||
# 无效动作,保持原地
|
||||
pass
|
||||
|
||||
# 检查是否移动
|
||||
car_moved = (new_row != current_row or new_col != current_col)
|
||||
# 更新车辆位置
|
||||
self.car_pos[current_car] = self.rectangles[(
|
||||
new_row, new_col)]['center']
|
||||
if new_row != current_row or new_col != current_col:
|
||||
if car_moved:
|
||||
self.car_traj[current_car].append((new_row, new_col))
|
||||
self.step_count += 1
|
||||
self.current_car_index = (
|
||||
self.current_car_index + 1) % self.num_cars
|
||||
|
||||
# 更新访问标记:将新网格标记为已访问
|
||||
self.rectangles[(new_row, new_col)]['is_visited'] = True
|
||||
|
||||
# 观察状态
|
||||
reward = 0
|
||||
# 记录所有车辆一轮中是否移动
|
||||
if self.current_car_index == 0:
|
||||
# 如果是新一轮的开始,初始化移动标记
|
||||
self.cars_moved = [False] * self.num_cars
|
||||
self.cars_moved[current_car] = car_moved
|
||||
# 如果一轮结束,检查是否所有车辆都没有移动
|
||||
if self.current_car_index == (self.num_cars - 1) and not any(self.cars_moved):
|
||||
reward -= 10 # 扣除 10 分奖励
|
||||
|
||||
self.step_count += 1
|
||||
self.current_car_index = (
|
||||
self.current_car_index + 1) % self.num_cars
|
||||
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||
visit_status = np.zeros(max_regions, dtype=np.float32)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user