修改环境
This commit is contained in:
parent
e35dd10326
commit
7ca5ce08b1
@ -50,7 +50,8 @@ def evaluate_policy(env, agent, turns = 3):
|
|||||||
action_series.append(a[0])
|
action_series.append(a[0])
|
||||||
total_scores += r
|
total_scores += r
|
||||||
s = s_next
|
s = s_next
|
||||||
print(np.round(action_series, 3))
|
print('action series: ', np.round(action_series, 3))
|
||||||
|
print('state: {s_next}')
|
||||||
return int(total_scores/turns)
|
return int(total_scores/turns)
|
||||||
|
|
||||||
|
|
||||||
|
60
env.py
60
env.py
@ -53,11 +53,11 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
low=0.0, high=1.0, shape=(1,), dtype=np.float32)
|
low=0.0, high=1.0, shape=(1,), dtype=np.float32)
|
||||||
|
|
||||||
# 定义观察空间为8维向量
|
# 定义观察空间为8维向量
|
||||||
# TODO 返回的状态目前只有位置坐标
|
|
||||||
# 阶段 0 状态:前 4 维表示已决策的切分值(未决策部分为 0)
|
# 阶段 0 状态:前 4 维表示已决策的切分值(未决策部分为 0)
|
||||||
# 阶段 1 状态:车辆位置 (2D)
|
# 阶段 1 状态:区域访问状态向量(长度为(CUT_NUM/2+1)^2)
|
||||||
|
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=0.0, high=1.0, shape=(self.CUT_NUM + 2 * self.num_cars,), dtype=np.float32)
|
low=0.0, high=1.0, shape=(self.CUT_NUM + max_regions,), dtype=np.float32)
|
||||||
|
|
||||||
# 切分阶段相关变量
|
# 切分阶段相关变量
|
||||||
self.col_cuts = [] # 存储竖切位置(c₁, c₂),当值为0时表示不切
|
self.col_cuts = [] # 存储竖切位置(c₁, c₂),当值为0时表示不切
|
||||||
@ -86,9 +86,13 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
self.car_pos = [(self.H / 2, self.W / 2) for _ in range(self.num_cars)]
|
self.car_pos = [(self.H / 2, self.W / 2) for _ in range(self.num_cars)]
|
||||||
self.car_traj = [[] for _ in range(self.num_cars)]
|
self.car_traj = [[] for _ in range(self.num_cars)]
|
||||||
self.current_car_index = 0
|
self.current_car_index = 0
|
||||||
# 状态:前 4 维为 partition_values,其余补 0
|
|
||||||
state = np.concatenate(
|
# 状态:前 4 维为 partition_values,其余为区域访问状态(初始全0)
|
||||||
[self.partition_values, np.zeros(np.array(self.car_pos).flatten().shape[0], dtype=np.float32)])
|
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||||
|
state = np.concatenate([
|
||||||
|
self.partition_values,
|
||||||
|
np.zeros(max_regions, dtype=np.float32)
|
||||||
|
])
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
@ -103,8 +107,10 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
self.partition_step += 1
|
self.partition_step += 1
|
||||||
|
|
||||||
# 构造当前状态:前 partition_step 个为已决策值,其余为 0,再补 7 个 0
|
# 构造当前状态:前 partition_step 个为已决策值,其余为 0,再补 7 个 0
|
||||||
state = np.concatenate(
|
state = np.concatenate([
|
||||||
[self.partition_values, np.zeros(np.array(self.car_pos).flatten().shape[0], dtype=np.float32)])
|
self.partition_values,
|
||||||
|
np.zeros((self.CUT_NUM // 2 + 1) ** 2, dtype=np.float32)
|
||||||
|
])
|
||||||
|
|
||||||
# 如果未完成 4 步,则仍处于切分阶段,不发奖励,done 为 False
|
# 如果未完成 4 步,则仍处于切分阶段,不发奖励,done 为 False
|
||||||
if self.partition_step < self.CUT_NUM:
|
if self.partition_step < self.CUT_NUM:
|
||||||
@ -153,8 +159,12 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
|
|
||||||
if not valid_partition:
|
if not valid_partition:
|
||||||
reward = -10000
|
reward = -10000
|
||||||
state = np.concatenate(
|
# 状态:前 4 维为 partition_values,其余为区域访问状态(初始全0)
|
||||||
[self.partition_values, np.zeros(np.array(self.car_pos).flatten().shape[0], dtype=np.float32)])
|
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||||
|
state = np.concatenate([
|
||||||
|
self.partition_values,
|
||||||
|
np.zeros(max_regions, dtype=np.float32)
|
||||||
|
])
|
||||||
return state, reward, True, False, {}
|
return state, reward, True, False, {}
|
||||||
else:
|
else:
|
||||||
# 进入阶段 1:初始化迷宫
|
# 进入阶段 1:初始化迷宫
|
||||||
@ -183,9 +193,19 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
|
|
||||||
# 进入阶段 2:走迷宫
|
# 进入阶段 2:走迷宫
|
||||||
self.phase = 2
|
self.phase = 2
|
||||||
state = np.concatenate(
|
|
||||||
[self.partition_values, np.array(self.car_pos).flatten()]
|
# 构造访问状态向量
|
||||||
)
|
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||||
|
visit_status = np.zeros(max_regions, dtype=np.float32)
|
||||||
|
|
||||||
|
# 将实际区域的访问状态填入向量
|
||||||
|
for i in range(len(self.row_cuts) - 1):
|
||||||
|
for j in range(len(self.col_cuts) - 1):
|
||||||
|
idx = i * (len(self.col_cuts) - 1) + j
|
||||||
|
visit_status[idx] = float(self.rectangles[(i, j)]['is_visited'])
|
||||||
|
for i in range(idx + 1, max_regions):
|
||||||
|
visit_status[i] = 100
|
||||||
|
state = np.concatenate([self.partition_values, visit_status])
|
||||||
return state, reward, False, False, {}
|
return state, reward, False, False, {}
|
||||||
|
|
||||||
elif self.phase == 2:
|
elif self.phase == 2:
|
||||||
@ -250,8 +270,18 @@ class PartitionMazeEnv(gym.Env):
|
|||||||
self.rectangles[(new_row, new_col)]['is_visited'] = True
|
self.rectangles[(new_row, new_col)]['is_visited'] = True
|
||||||
|
|
||||||
# 观察状态
|
# 观察状态
|
||||||
state = np.concatenate(
|
# 构造访问状态向量
|
||||||
[self.partition_values, np.array(self.car_pos).flatten()])
|
max_regions = (self.CUT_NUM // 2 + 1) ** 2
|
||||||
|
visit_status = np.zeros(max_regions, dtype=np.float32)
|
||||||
|
|
||||||
|
# 将实际区域的访问状态填入向量
|
||||||
|
for i in range(len(self.row_cuts) - 1):
|
||||||
|
for j in range(len(self.col_cuts) - 1):
|
||||||
|
idx = i * (len(self.col_cuts) - 1) + j
|
||||||
|
visit_status[idx] = float(self.rectangles[(i, j)]['is_visited'])
|
||||||
|
for i in range(idx + 1, max_regions):
|
||||||
|
visit_status[i] = 100
|
||||||
|
state = np.concatenate([self.partition_values, visit_status])
|
||||||
|
|
||||||
# Episode 终止条件:所有网格均被访问或步数达到上限
|
# Episode 终止条件:所有网格均被访问或步数达到上限
|
||||||
done = all([value['is_visited'] for _, value in self.rectangles.items()]) or (
|
done = all([value['is_visited'] for _, value in self.rectangles.items()]) or (
|
||||||
|
@ -6,7 +6,7 @@ env = PartitionMazeEnv()
|
|||||||
state = env.reset()
|
state = env.reset()
|
||||||
print(state)
|
print(state)
|
||||||
|
|
||||||
action_series = [[0.1], [0.2], [0.4], [0], [0.1]]
|
action_series = [[0], [0], [0.4], [0], [0.1]]
|
||||||
# action_series = [0, 0, 3, 0, 0, 10]
|
# action_series = [0, 0, 3, 0, 0, 10]
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
|
Loading…
Reference in New Issue
Block a user