HPCC2025/env_partion.py
weixin_46229132 84f69f4293 离散情况
2025-03-29 21:28:39 +08:00

204 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import yaml
import math
from mTSP_solver import mTSP
from GA.ga import GA
class PartitionEnv(gym.Env):
    """Gymnasium environment that tunes the cut lines of a rectangular area.

    The area starts from a fixed grid defined by ``ORI_ROW_CUTS`` and
    ``ORI_COL_CUTS`` (normalised to [0, 1]).  Every step nudges exactly one
    interior cut line by a small continuous offset: first the
    ``ROW_CUT_LIMIT`` interior row cuts, then the interior column cuts.
    After each valid adjustment the resulting rectangles are routed with the
    genetic-algorithm multi-vehicle solver and the reward is
    ``BASE_LINE / best_time``.
    """

    def __init__(self, config=None):
        """Set up hyper-parameters, spaces and the fleet parameters.

        Args:
            config: Currently unused.

        Raises:
            OSError: If ``<self.params>.yml`` cannot be opened.
            KeyError: If a required key is missing from the YAML file.
        """
        super().__init__()

        ##############################
        # Hyper-parameters that may need to be edited by hand
        ##############################
        self.params = 'params2'
        self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
        self.ORI_COL_CUTS = [0, 0.5, 1]
        self.CUT_NUM = 4
        self.ROW_CUT_LIMIT = 3
        self.COL_CUT_LIMIT = 1
        self.BASE_LINE = 10000
        self.mTSP_STEPS = 10000

        # Action: one continuous offset in [-0.1, 0.1] added to a single cut.
        self.action_space = spaces.Box(
            low=-0.1, high=0.1, shape=(1,), dtype=np.float32)

        # Observation: all row cuts followed by all column cuts
        # (5 + 3 = CUT_NUM + 4 values with the defaults above).
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.CUT_NUM + 4,), dtype=np.float32)

        # ``phase`` is initialised here so that render() is safe to call
        # before the first reset().
        self.phase = 0
        self.partition_step = 0
        self.ori_row_cuts = self.ORI_ROW_CUTS[:]
        self.ori_col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []

        # Fleet / cost-model parameters.
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        self.H = params['H']
        self.W = params['W']
        self.center = (self.H/2, self.W/2)
        self.num_cars = params['num_cars']

        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

    def _get_state(self):
        """Return the observation: row cuts followed by column cuts."""
        # NOTE(review): this array is float64 while observation_space declares
        # float32; left as-is so existing callers see the same dtype.
        return np.array(self.ori_row_cuts + self.ori_col_cuts)

    def reset(self, seed=None, options=None):
        """Restore the original cut lines and return the initial observation.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that
                ``self.np_random`` is seeded.
            options: Unused.

        Returns:
            The initial observation array.
            NOTE(review): only the observation is returned, not the
            ``(observation, info)`` pair used by the Gymnasium API; kept for
            compatibility with existing callers.
        """
        super().reset(seed=seed)

        self.phase = 0
        self.partition_step = 0
        self.ori_row_cuts = self.ORI_ROW_CUTS[:]
        self.ori_col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []

        return self._get_state()

    def step(self, action):
        """Shift the next cut line by ``action[0]`` and score the partition.

        Args:
            action: Indexable whose first element is the offset to add to
                the cut line that is adjusted in this step.

        Returns:
            ``(state, reward, terminated, truncated, info)``.
            * Cut no longer strictly between its neighbours: reward ``-100``,
              episode terminates, ``info`` is ``{}``.
            * ``if_valid_partition`` returns no rectangles: reward ``-10``,
              episode terminates, ``info`` is ``{}``.
            * Otherwise reward is ``BASE_LINE / best_time`` (tripled on the
              final step) and ``info`` is the solver's best path.
              NOTE(review): ``info`` is not a dict in this case.
        """
        adjust = float(action[0])

        # The first ROW_CUT_LIMIT steps move interior row cuts, the
        # remaining steps move interior column cuts.
        if self.partition_step < self.ROW_CUT_LIMIT:
            cuts = self.ori_row_cuts
            idx = self.partition_step + 1
        else:
            cuts = self.ori_col_cuts
            idx = self.partition_step - self.ROW_CUT_LIMIT + 1

        # The adjusted value is stored even when it turns out to be invalid,
        # so the returned state reflects the attempted move.
        cuts[idx] = cuts[idx] + adjust
        valid_adjust = cuts[idx - 1] < cuts[idx] < cuts[idx + 1]

        self.partition_step += 1
        state = self._get_state()

        # An invalid adjustment ends the episode immediately.
        if not valid_adjust:
            return state, -100, True, False, {}

        rectangles = self.if_valid_partition()
        if not rectangles:
            return state, -10, True, False, {}

        # Route the rectangles with the genetic-algorithm solver.
        best_time, best_path = self.ga_solver(rectangles)
        reward = self.BASE_LINE / best_time

        done = self.partition_step >= self.CUT_NUM
        if done:
            reward = reward * 3

        return state, reward, done, False, best_path

    def if_valid_partition(self):
        """Build the per-rectangle cost records for the current cut lines.

        Returns:
            A list with one dict per grid cell (row-major order) holding
            ``'center'``, ``'flight_time'`` and ``'bs_time'``; or an empty
            list as soon as a cell yields a negative ``rho_energy_limit``.
        """
        # Independent of the cell, so computed once.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)

        rectangles = []
        row_pairs = zip(self.ori_row_cuts, self.ori_row_cuts[1:])
        for row_lo, row_hi in row_pairs:
            col_pairs = zip(self.ori_col_cuts, self.ori_col_cuts[1:])
            for col_lo, col_hi in col_pairs:
                # Area of the cell in un-normalised units.
                d = (col_hi - col_lo) * self.W * (row_hi - row_lo) * self.H

                rho_energy_limit = (
                    self.battery_energy_capacity
                    - self.flight_energy_factor * d
                    - self.trans_energy_factor * d
                ) / (self.comp_energy_factor * d - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return []

                # presumably rho is a ratio bounded by both the time and the
                # energy constraint -- TODO confirm against the model.
                rho = min(rho_time_limit, rho_energy_limit)

                rectangles.append({
                    'center': ((row_lo + row_hi) * self.H / 2,
                               (col_hi + col_lo) * self.W / 2),
                    'flight_time': self.flight_time_factor * d,
                    'bs_time': self.bs_time_factor * (1 - rho) * d,
                })
        return rectangles

    # Disabled alternative: solve the multi-vehicle routing with Q-learning.
    # def q_learning_solver(self):
    #     cities: [[x1, x2, x3...], [y1, y2, y3...]] city coordinates
    #     rec_center_lt = [rec_info['center']
    #                      for rec_info in rectangles]
    #     cities = np.column_stack(rec_center_lt)
    #     cities = np.column_stack((self.center, cities))
    #     center_idx = []
    #     for i in range(self.num_cars - 1):
    #         cities = np.column_stack((cities, self.center))
    #         center_idx.append(cities.shape[1] - 1)
    #     tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
    #                center_idx=center_idx, rectangles=rectangles)
    #     best_time, best_path = tsp.train(self.mTSP_STEPS)

    def ga_solver(self, rectangles):
        """Route the rectangles with the genetic-algorithm solver.

        Args:
            rectangles: Records produced by ``if_valid_partition``.

        Returns:
            ``(best_time, best_path)`` as reported by ``GA.run()``.
        """
        # Row 0 is the area centre, followed by every rectangle centre, then
        # one more copy of the area centre per additional car.
        points = [self.center]
        points.extend(rec['center'] for rec in rectangles)

        extra_centers = max(self.num_cars - 1, 0)
        first_extra = len(points)
        points.extend([self.center] * extra_centers)

        # Built in one shot instead of repeatedly calling np.row_stack, which
        # is deprecated in NumPy 2.0.
        cities = np.array(points)
        center_idx = [0, *range(first_extra, first_extra + extra_centers)]

        ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
                data=cities, to_process_idx=center_idx, rectangles=rectangles)
        best_path, best_time = ga.run()
        return best_time, best_path

    def render(self):
        """Print debugging information for phases 1 and 2.

        NOTE(review): nothing in this class sets ``phase`` to 1 or 2, nor
        assigns ``partition_values``, ``car_pos`` or ``car_traj``; they are
        read defensively so that this method cannot raise AttributeError.
        """
        if self.phase == 1:
            print("Phase 1: Initialize maze environment.")
            print(f"Partition values so far: {getattr(self, 'partition_values', None)}")
            print(f"Motorcade positon: {getattr(self, 'car_pos', None)}")
            # input('1111')
        elif self.phase == 2:
            print("Phase 2: Play maze.")
            print(f'Motorcade trajectory: {getattr(self, "car_traj", None)}')
            # input('2222')