HPCC2025/env_partion.py
weixin_46229132 f347ca8276 微调分区
2025-03-29 16:28:30 +08:00

198 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import yaml
import math
from mTSP_solver import mTSP
from GA.ga import GA
class PartitionEnv(gym.Env):
    """Environment in which an agent adjusts the cuts that split an area.

    An episode lasts ``CUT_NUM`` steps.  The first ``ROW_CUT_LIMIT`` steps
    move the interior row cuts in order; the remaining steps move the
    interior column cuts.  Each action is a single continuous value that is
    added to the current cut.  Cuts are normalized to [0, 1] and are scaled
    by ``H`` (rows) and ``W`` (columns) when areas and centers are computed.

    Episode outcomes:
      * a cut that leaves the open interval between its neighbours ends the
        episode with reward -100;
      * after the last cut, if any rectangle's energy-based offload limit is
        negative the episode ends with reward -10;
      * otherwise a genetic-algorithm multi-TSP solver plans routes over the
        rectangle centers and the reward is ``BASE_LINE / best_time``.
    """

    def __init__(self, config=None):
        # ``config`` is accepted but not used.
        super().__init__()
        ##############################
        # Hyper-parameters that may need manual tuning
        ##############################
        # Base name of the YAML parameter file; '<params>.yml' is opened
        # relative to the current working directory.
        self.params = 'params2'
        self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
        self.ORI_COL_CUTS = [0, 0.5, 1]
        self.CUT_NUM = 4
        self.ROW_CUT_LIMIT = 3
        self.COL_CUT_LIMIT = 1
        self.BASE_LINE = 10000
        # Step budget for the Q-learning mTSP solver; not used by the GA
        # solver this class currently calls.
        self.mTSP_STEPS = 10000
        # Every action is one continuous value in [0, 1] that is added to
        # the current cut.
        # NOTE(review): since the action is never negative, cuts can only
        # move towards 1 -- confirm this is intended.
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(1,), dtype=np.float32)
        # Observation: every row cut followed by every column cut
        # (5 + 3 = 8 values with the defaults above).  An invalid attempted
        # cut is still reported and may fall outside [0, 1].
        self.observation_space = spaces.Box(
            low=0.0, high=1.0,
            shape=(len(self.ORI_ROW_CUTS) + len(self.ORI_COL_CUTS),),
            dtype=np.float32)
        # Initialized here as well as in reset() so render() is safe to call
        # before the first reset.
        self.phase = 0
        self.partition_step = 0
        self.ori_row_cuts = self.ORI_ROW_CUTS[:]
        self.ori_col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        # Area size and fleet time/energy cost factors.
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)
        self.H = params['H']
        self.W = params['W']
        self.center = (self.H / 2, self.W / 2)
        self.num_cars = params['num_cars']
        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']
        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

    def reset(self, seed=None, options=None):
        """Restart the episode from the initial cuts.

        Args:
            seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: accepted for API compatibility; unused.

        Returns:
            The initial observation (float32 vector of row then column cuts).
        """
        # NOTE(review): Gymnasium's API expects ``(observation, info)`` here;
        # only the observation is returned to stay compatible with existing
        # callers.
        super().reset(seed=seed)
        self.phase = 0
        self.partition_step = 0
        self.ori_row_cuts = self.ORI_ROW_CUTS[:]
        self.ori_col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        return self._get_obs()

    def step(self, action):
        """Apply one cut adjustment.

        Args:
            action: sequence whose first element is added to the cut chosen
                by the current step (row cuts first, then column cuts).

        Returns:
            ``(state, reward, terminated, truncated, info)``.  On the final
            successful step the fifth element is the ``best_path`` returned
            by the GA solver instead of a dict.
        """
        # NOTE(review): Gymnasium expects ``info`` to be a dict; ``best_path``
        # is returned in that slot for compatibility with existing callers.
        adjust = float(action[0])
        if self.partition_step < self.ROW_CUT_LIMIT:
            valid_adjust = self._shift_cut(
                self.ori_row_cuts, self.partition_step, adjust)
        else:
            valid_adjust = self._shift_cut(
                self.ori_col_cuts, self.partition_step - self.ROW_CUT_LIMIT,
                adjust)
        self.partition_step += 1
        state = self._get_obs()

        # An invalid adjustment ends the episode immediately.
        if not valid_adjust:
            return state, -100, True, False, {}
        if self.partition_step < self.CUT_NUM:
            return state, 0.0, False, False, {}

        # All cuts placed: evaluate each rectangle, then plan routes.
        if not self._build_rectangles():
            return state, -10, True, False, {}
        best_path, best_time = self._solve_routes()
        reward = self.BASE_LINE / best_time
        return state, reward, True, False, best_path

    def _get_obs(self):
        """Return the current row cuts followed by column cuts as float32."""
        return np.array(self.ori_row_cuts + self.ori_col_cuts,
                        dtype=np.float32)

    @staticmethod
    def _shift_cut(cuts, idx, adjust):
        """Add ``adjust`` to interior cut ``cuts[idx + 1]`` in place.

        The new value is written even when invalid, so the observation shows
        the attempted cut.

        Returns:
            True if the new cut lies strictly between ``cuts[idx]`` and
            ``cuts[idx + 2]``.
        """
        new_cut = cuts[idx + 1] + adjust
        cuts[idx + 1] = new_cut
        return cuts[idx] < new_cut < cuts[idx + 2]

    def _build_rectangles(self):
        """Append one entry per rectangle (row-major) to ``self.rectangles``.

        For each rectangle, compute its area ``d`` and the offload ratio
        ``rho`` as the minimum of the time-based and energy-based limits.
        Then store the rectangle's center, ``flight_time`` and ``bs_time``.

        Returns:
            False as soon as a rectangle's energy-based limit is negative
            (rectangles appended so far are kept), otherwise True.
        """
        # Independent of the rectangle, so computed once.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)
        rows = self.ori_row_cuts
        cols = self.ori_col_cuts
        for r0, r1 in zip(rows, rows[1:]):
            for c0, c1 in zip(cols, cols[1:]):
                d = (c1 - c0) * self.W * (r1 - r0) * self.H
                rho_energy_limit = (
                    self.battery_energy_capacity
                    - self.flight_energy_factor * d
                    - self.trans_energy_factor * d
                ) / (self.comp_energy_factor * d - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return False
                rho = min(rho_time_limit, rho_energy_limit)
                self.rectangles.append({
                    'center': ((r0 + r1) * self.H / 2,
                               (c1 + c0) * self.W / 2),
                    'flight_time': self.flight_time_factor * d,
                    'bs_time': self.bs_time_factor * (1 - rho) * d,
                })
        return True

    def _solve_routes(self):
        """Solve the multi-vehicle TSP over the rectangle centers with GA.

        City 0 is ``self.center``.  It is repeated once per additional car
        at the end of the city list, and every index holding it is passed to
        the solver as ``to_process_idx``.  An earlier variant used the
        Q-learning ``mTSP`` solver here instead.

        Returns:
            ``(best_path, best_time)`` as returned by ``GA.run()``.
        """
        rect_centers = [rec['center'] for rec in self.rectangles]
        extra_centers = [self.center] * (self.num_cars - 1)
        cities = np.array([self.center] + rect_centers + extra_centers)
        first_extra = 1 + len(rect_centers)
        center_idx = [0] + list(
            range(first_extra, first_extra + self.num_cars - 1))
        ga = GA(num_drones=self.num_cars, num_city=cities.shape[0],
                num_total=20, data=cities, to_process_idx=center_idx,
                rectangles=self.rectangles)
        best_path, best_time = ga.run()
        return best_path, best_time

    def render(self):
        """Print debugging information for phases 1 and 2.

        NOTE(review): this class only ever sets ``self.phase`` to 0, and
        ``partition_values``, ``car_pos`` and ``car_traj`` are not set
        anywhere in it.  Both branches look like leftovers from an earlier
        two-phase design; unless ``phase`` is changed externally, nothing is
        printed.
        """
        if self.phase == 1:
            print("Phase 1: Initialize maze environment.")
            print(f"Partition values so far: {self.partition_values}")
            print(f"Motorcade positon: {self.car_pos}")
        elif self.phase == 2:
            print("Phase 2: Play maze.")
            print(f'Motorcade trajectory: {self.car_traj}')