修改env数据结构

This commit is contained in:
weixin_46229132 2025-03-13 15:09:58 +08:00
parent 1f18d9d96f
commit aecd86b245

View File

@ -54,8 +54,8 @@ class PartitionMazeEnv(gym.Env):
low=0.0, high=1.0, shape=(4 + 2 * self.num_cars,), dtype=np.float32)
# 切分阶段相关变量
self.vertical_cuts = [] # 存储竖切位置c₁, c₂当值为0时表示不切
self.horizontal_cuts = [] # 存储横切位置r₁, r₂
self.col_cuts = [] # 存储竖切位置c₁, c₂当值为0时表示不切
self.row_cuts = [] # 存储横切位置r₁, r₂
self.init_maze_step = 0
@ -64,7 +64,7 @@ class PartitionMazeEnv(gym.Env):
self.BASE_LINE = 3500.0 # 基准时间通过greedy或者蒙特卡洛计算出来
self.step_count = 0
self.rectangles = {}
self.car_pos = [[0.5, 0.5] for _ in range(self.num_cars)]
self.car_pos = [(0, 0) for _ in range(self.num_cars)]
self.car_traj = [[] for _ in range(self.num_cars)]
self.current_car_index = 0
@ -73,13 +73,13 @@ class PartitionMazeEnv(gym.Env):
self.phase = 0
self.partition_step = 0
self.partition_values = np.zeros(4, dtype=np.float32)
self.vertical_cuts = []
self.horizontal_cuts = []
self.col_cuts = []
self.row_cuts = []
self.init_maze_step = 0
self.region_centers = []
self.step_count = 0
self.rectangles = {}
self.car_pos = [[0.5, 0.5] for _ in range(self.num_cars)]
self.car_pos = [(0, 0) for _ in range(self.num_cars)]
self.car_traj = [[] for _ in range(self.num_cars)]
self.current_car_index = 0
# 状态:前 4 维为 partition_values其余补 0
@ -112,19 +112,19 @@ class PartitionMazeEnv(gym.Env):
self.partition_values) // 2] if v > 0))
horiz = sorted(set(v for v in self.partition_values[len(
self.partition_values) // 2:] if v > 0))
self.vertical_cuts = vert if vert else []
self.horizontal_cuts = horiz if horiz else []
vertical_cuts = vert if vert else []
horizontal_cuts = horiz if horiz else []
# 边界:始终包含 0 和 1
v_boundaries = [0.0] + self.vertical_cuts + [1.0]
h_boundaries = [0.0] + self.horizontal_cuts + [1.0]
self.col_cuts = [0.0] + vertical_cuts + [1.0]
self.row_cuts = [0.0] + horizontal_cuts + [1.0]
# 判断分区是否合理,并计算各个分区的任务卸载率ρ
valid_partition = True
for i in range(len(h_boundaries) - 1):
for j in range(len(v_boundaries) - 1):
d = (v_boundaries[j+1] - v_boundaries[j]) * self.W * \
(h_boundaries[i] + h_boundaries[i+1]) * self.H
for i in range(len(self.row_cuts) - 1):
for j in range(len(self.col_cuts) - 1):
d = (self.col_cuts[j+1] - self.col_cuts[j]) * self.W * \
(self.row_cuts[i] + self.row_cuts[i+1]) * self.H
rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
(self.comp_time_factor - self.trans_time_factor)
rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
@ -139,7 +139,7 @@ class PartitionMazeEnv(gym.Env):
bs_time = self.bs_time_factor * (1 - rho) * d
self.rectangles[(i, j)] = {
'center': ((v_boundaries[j+1] + v_boundaries[j]) * self.W / 2, (h_boundaries[i] + h_boundaries[i+1]) * self.H / 2),
'center': ((self.row_cuts[i] + self.row_cuts[i+1]) * self.H / 2, (self.col_cuts[j+1] + self.col_cuts[j]) * self.W / 2),
'flight_time': flight_time,
'bs_time': bs_time,
'is_visited': False
@ -148,7 +148,7 @@ class PartitionMazeEnv(gym.Env):
break
if not valid_partition:
reward = -100
reward = -10000
state = np.concatenate(
[self.partition_values, np.zeros(np.array(self.car_pos).flatten().shape[0], dtype=np.float32)])
return state, reward, True, False, {}
@ -157,9 +157,9 @@ class PartitionMazeEnv(gym.Env):
# 进入阶段 1初始化迷宫
self.phase = 1
# 存储切分边界,供后续网格映射使用
self.v_boundaries = v_boundaries
self.h_boundaries = h_boundaries
# 所有车队从整个区域的中心出发
self.car_pos = [(len(self.row_cuts) - 2 / 2, len(self.col_cuts) -2 / 2)
for _ in range(self.num_cars)]
state = np.concatenate(
[self.partition_values, np.array(self.car_pos).flatten()])
return state, reward, False, False, {}
@ -167,17 +167,17 @@ class PartitionMazeEnv(gym.Env):
elif self.phase == 1:
# 阶段 1初始化迷宫让多个车辆从区域中心出发前往划分区域的中心点
# 确保 action 的值在 [0, 1],然后映射到 0~(num_regions-1) 的索引
num_regions = (len(self.v_boundaries) - 1) * \
(len(self.h_boundaries) - 1)
num_regions = (len(self.col_cuts) - 1) * \
(len(self.row_cuts) - 1)
target_region_index = int(np.floor(a * num_regions))
target_region_index = np.clip(
target_region_index, 0, num_regions - 1)
# 将index映射到笛卡尔坐标
coord = [target_region_index // (len(self.v_boundaries) - 1),
target_region_index % (len(self.v_boundaries) - 1)]
coord = (target_region_index // (len(self.col_cuts) - 1),
target_region_index % (len(self.col_cuts) - 1))
self.car_pos[self.init_maze_step] = coord
self.car_traj[self.init_maze_step].append(coord)
self.rectangles[tuple(coord)]['is_visited'] = True
self.rectangles[coord]['is_visited'] = True
# 计数
self.init_maze_step += 1
@ -210,21 +210,21 @@ class PartitionMazeEnv(gym.Env):
# 初始化新的行、列为当前值
new_row, new_col = current_row, current_col
if move_dir == 'up' and current_row < len(self.h_boundaries) - 2:
new_row = current_row + 1
elif move_dir == 'down' and current_row > 0:
if move_dir == 'up' and current_row > 0:
new_row = current_row - 1
elif move_dir == 'down' and current_row < len(self.row_cuts) - 2:
new_row = current_row + 1
elif move_dir == 'left' and current_col > 0:
new_col = current_col - 1
elif move_dir == 'right' and current_col < len(self.v_boundaries) - 2:
elif move_dir == 'right' and current_col < len(self.col_cuts) - 2:
new_col = current_col + 1
# 如果移动不合法或者动作为stay则保持原位置
# TODO 移动不合法,加一些惩罚
# 更新车辆位置
self.car_pos[current_car] = [new_row, new_col]
self.car_pos[current_car] = (new_row, new_col)
if new_row != current_row or new_col != current_col:
self.car_traj[current_car].append([new_row, new_col])
self.car_traj[current_car].append((new_row, new_col))
self.step_count += 1
self.current_car_index = (
self.current_car_index + 1) % self.num_cars
@ -260,15 +260,15 @@ class PartitionMazeEnv(gym.Env):
# 计算车的移动时间,首先在轨迹的首尾添加上大区域中心
car_time = 0
# self.car_traj[idx].append([0.5, 0.5])
# self.car_traj[idx].insert(0, [0.5, 0.5])
for i in range(len(self.car_traj[idx]) - 1):
first_point = self.car_traj[idx][i]
second_point = self.car_traj[idx][i + 1]
car_time += math.dist(self.rectangles[tuple(first_point)]['center'], self.rectangles[tuple(second_point)]['center']) * \
car_time += math.dist(self.rectangles[first_point]['center'], self.rectangles[second_point]['center']) * \
self.car_time_factor
car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][0])]['center'], [self.W / 2, self.H / 2]) * self.car_time_factor
car_time += math.dist(self.rectangles[tuple(self.car_traj[idx][-1])]['center'], [self.W / 2, self.H / 2]) * self.car_time_factor
car_time += math.dist(self.rectangles[self.car_traj[idx][0]]['center'], [
self.H / 2, self.W / 2]) * self.car_time_factor
car_time += math.dist(self.rectangles[self.car_traj[idx][-1]]['center'], [
self.H / 2, self.W / 2]) * self.car_time_factor
return max(float(car_time) + flight_time, bs_time)