HPCC2025/env_partion.py

204 lines
7.6 KiB
Python
Raw Normal View History

2025-03-28 21:37:31 +08:00
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import yaml
import math
from mTSP_solver import mTSP
2025-03-29 11:43:04 +08:00
from GA.ga import GA
2025-03-28 21:37:31 +08:00
class PartitionEnv(gym.Env):
    """Environment that tunes a rectangular grid partition one cut at a time.

    The region is divided by row and column cuts stored as fractions of the
    region's height and width. Each step shifts one interior cut by the
    continuous action value: the ``ROW_CUT_LIMIT`` interior row cuts are
    adjusted first, then the ``COL_CUT_LIMIT`` interior column cuts. After
    every valid shift the resulting cells are routed with the genetic
    algorithm solver and the reward is ``BASE_LINE / best_time``.
    """

    def __init__(self, config=None):
        """Create the spaces, the initial cuts and load the fleet parameters.

        Args:
            config: Accepted for interface compatibility; currently unused.

        Raises:
            OSError: If ``<self.params>.yml`` cannot be opened.
            KeyError: If the parameter file lacks one of the required keys.
        """
        super().__init__()

        ##############################
        # Hyper-parameters that may need manual tuning
        ##############################
        self.params = 'params2'
        self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
        self.ORI_COL_CUTS = [0, 0.5, 1]
        self.CUT_NUM = 4
        self.ROW_CUT_LIMIT = 3
        self.COL_CUT_LIMIT = 1
        self.BASE_LINE = 10000
        self.mTSP_STEPS = 10000

        # Action space: a single continuous value in [0, 1].
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(1,), dtype=np.float32)
        # Observation: every row cut followed by every column cut. The length
        # is derived from the cut lists so it always matches the state built
        # in reset()/step().
        state_dim = len(self.ORI_ROW_CUTS) + len(self.ORI_COL_CUTS)
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(state_dim,), dtype=np.float32)

        # Initialised here so render() is safe to call before reset().
        self.phase = 0
        self.partition_step = 0
        self.ori_row_cuts = self.ORI_ROW_CUTS[:]
        self.ori_col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []

        # Fleet parameters
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)
        self.H = params['H']
        self.W = params['W']
        self.center = (self.H / 2, self.W / 2)
        self.num_cars = params['num_cars']
        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']
        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

    def _get_state(self):
        """Return the current cuts (rows, then columns) as a float32 vector."""
        return np.array(self.ori_row_cuts + self.ori_col_cuts,
                        dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Restore the original cuts and return the initial state.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Unused.

        Returns:
            The state vector: all row cuts followed by all column cuts.
        """
        super().reset(seed=seed)
        self.phase = 0
        self.partition_step = 0
        self.ori_row_cuts = self.ORI_ROW_CUTS[:]
        self.ori_col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        # NOTE(review): Gymnasium's API expects ``(observation, info)`` here.
        # Only the observation is returned so existing callers keep working.
        return self._get_state()

    def step(self, action):
        """Shift the next interior cut by ``action[0]`` and score the result.

        Args:
            action: Indexable whose first element is the shift applied to the
                cut selected by the current ``partition_step``.

        Returns:
            A 5-tuple ``(state, reward, done, truncated, info)``:

            * the shifted cut left its neighbours' open interval:
              reward ``-100``, ``done`` is True, ``info`` is ``{}``;
            * ``if_valid_partition`` rejected the grid: reward ``-10``,
              ``done`` is True, ``info`` is ``{}``;
            * otherwise: reward ``BASE_LINE / best_time`` (tripled on the
              final cut), and the fifth element is the solver's best path.
        """
        adjust = float(action[0])

        # Select the interior cut for this step: row cuts first, then columns.
        if self.partition_step < self.ROW_CUT_LIMIT:
            cuts = self.ori_row_cuts
            idx = self.partition_step + 1
        else:
            cuts = self.ori_col_cuts
            idx = self.partition_step - self.ROW_CUT_LIMIT + 1

        # The shifted value is stored even when invalid, so the returned state
        # reflects the attempted adjustment.
        new_cut = cuts[idx] + adjust
        cuts[idx] = new_cut
        valid_adjust = cuts[idx - 1] < new_cut < cuts[idx + 1]

        self.partition_step += 1
        state = self._get_state()

        # An invalid adjustment ends the episode immediately.
        if not valid_adjust:
            return state, -100, True, False, {}

        # The adjustment is valid: check feasibility of every cell.
        rectangles = self.if_valid_partition()
        if not rectangles:
            return state, -10, True, False, {}

        # Plan the routes: solve the multi-traveller problem with the GA.
        best_time, best_path = self.ga_solver(rectangles)
        reward = self.BASE_LINE / best_time
        done = self.partition_step >= self.CUT_NUM
        if done:
            # NOTE(review): the x3 bonus is applied only on the terminal step;
            # the nesting was reconstructed from a de-indented source, confirm.
            reward = reward * 3
        # NOTE(review): Gymnasium expects ``info`` to be a dict; the raw path
        # is returned to stay compatible with existing callers.
        return state, reward, done, False, best_path

    def if_valid_partition(self):
        """Build per-cell task data, or return ``[]`` if any cell is infeasible.

        For every grid cell the area ``d`` is computed from the current cuts
        and ``H``/``W``. A cell is infeasible when its energy-derived ratio
        limit is negative or when it has no positive area.

        Returns:
            A list with one dict per cell (row-major order) holding
            ``'center'``, ``'flight_time'`` and ``'bs_time'``; an empty list
            when the partition is rejected.
        """
        # Independent of the cell, so computed once.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)

        rectangles = []
        row_cuts = self.ori_row_cuts
        col_cuts = self.ori_col_cuts
        for top, bottom in zip(row_cuts, row_cuts[1:]):
            for left, right in zip(col_cuts, col_cuts[1:]):
                # A degenerate cell would make the energy limit divide by zero;
                # treat it as an invalid partition instead.
                if bottom - top <= 0 or right - left <= 0:
                    return []

                d = (right - left) * self.W * (bottom - top) * self.H

                rho_energy_limit = (
                    self.battery_energy_capacity
                    - self.flight_energy_factor * d
                    - self.trans_energy_factor * d
                ) / (self.comp_energy_factor * d
                     - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return []

                # presumably the share of the workload processed on board
                # rather than offloaded; verify against the model definition.
                rho = min(rho_time_limit, rho_energy_limit)

                rectangles.append({
                    'center': ((top + bottom) * self.H / 2,
                               (right + left) * self.W / 2),
                    'flight_time': self.flight_time_factor * d,
                    'bs_time': self.bs_time_factor * (1 - rho) * d,
                })
        return rectangles

    # def q_learning_solver(self):
    #     Solve the multi-traveller problem with Q-learning.
    #     cities: [[x1, x2, x3...], [y1, y2, y3...]] city coordinates
    #     rec_center_lt = [rec_info['center']
    #                      for rec_info in rectangles]
    #     cities = np.column_stack(rec_center_lt)
    #     cities = np.column_stack((self.center, cities))
    #     center_idx = []
    #     for i in range(self.num_cars - 1):
    #         cities = np.column_stack((cities, self.center))
    #         center_idx.append(cities.shape[1] - 1)
    #     tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
    #                center_idx=center_idx, rectangles=rectangles)
    #     best_time, best_path = tsp.train(self.mTSP_STEPS)

    def ga_solver(self, rectangles):
        """Run the genetic algorithm over the cell centres.

        The coordinate table starts with the region centre, continues with
        each rectangle's centre, and ends with ``num_cars - 1`` additional
        copies of the region centre. The indices of all centre entries are
        handed to ``GA`` as ``to_process_idx``.

        Args:
            rectangles: Cell dicts as produced by ``if_valid_partition``.

        Returns:
            ``(best_time, best_path)`` as reported by ``GA.run``.
        """
        extra_centers = max(self.num_cars - 1, 0)

        # Build the full coordinate list once instead of re-stacking the
        # array per vehicle (np.row_stack is deprecated in NumPy 2.0).
        points = [self.center]
        points.extend(rec['center'] for rec in rectangles)
        first_extra = len(points)
        points.extend([self.center] * extra_centers)
        cities = np.array(points)

        center_idx = [0] + list(range(first_extra, first_extra + extra_centers))

        ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
                data=cities, to_process_idx=center_idx, rectangles=rectangles)
        best_path, best_time = ga.run()
        return best_time, best_path

    def render(self):
        """Print phase-specific debugging information.

        Nothing is printed in phase 0, which is the only phase this class
        sets. The attributes printed in phases 1 and 2 are not assigned
        anywhere in this class, so they are read defensively and show
        ``None`` when absent.
        """
        if self.phase == 1:
            print("Phase 1: Initialize maze environment.")
            print(f"Partition values so far: {getattr(self, 'partition_values', None)}")
            print(f"Motorcade positon: {getattr(self, 'car_pos', None)}")
        elif self.phase == 2:
            print("Phase 2: Play maze.")
            print(f"Motorcade trajectory: {getattr(self, 'car_traj', None)}")