HPCC2025/env_partion_dist.py
2025-04-12 22:55:01 +08:00

276 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import yaml
import math
from mTSP_solver import mTSP
from GA.ga import GA
class PartitionEnv(gym.Env):
    """Environment that fine-tunes the cut positions of a grid partition.

    The area (``H`` x ``W``, read from a YAML parameter file) is split by a
    list of row cuts and a list of column cuts, each stored as a fraction of
    the side length. Every step nudges one cut by +/-0.01 (or does nothing),
    the resulting rectangles are checked by :meth:`if_valid_partition`, a
    route over the rectangle centres is evaluated with the project ``GA``
    solver, and the reward compares the resulting time with ``BASE_LINE``.

    An episode ends after ``MAX_ADJUST_STEP`` steps.
    """

    # Effect of action ids 1..8, in order: (cut-list attribute, index, delta).
    # Action 0 and any id without an entry here leave the cuts unchanged.
    _ACTION_EFFECTS = (
        ('row_cuts', 1, +0.01),
        ('row_cuts', 1, -0.01),
        ('row_cuts', 2, +0.01),
        ('row_cuts', 2, -0.01),
        ('row_cuts', 3, +0.01),
        ('row_cuts', 3, -0.01),
        ('col_cuts', 1, +0.01),
        ('col_cuts', 1, -0.01),
    )

    def __init__(self, config=None):
        """Set up spaces, the initial cuts and the fleet parameters.

        Args:
            config: Accepted for API compatibility; not read by this class.

        Raises:
            OSError: If ``<self.params>.yml`` cannot be opened (the path is
                resolved relative to the current working directory).
            KeyError: If a required entry is missing from the YAML file.
        """
        super().__init__()

        ##############################
        # Hyper-parameters that may need manual tuning
        ##############################
        self.params = 'params_100_100_6'
        self.ORI_ROW_CUTS = [0, 0.28, 0.43, 0.62, 0.77, 1]
        self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1]
        self.CUT_NUM = 5
        self.BASE_LINE = 19376.06
        self.MAX_ADJUST_STEP = 50

        # One no-op plus a +/-0.01 move per cut.
        # NOTE(review): this declares CUT_NUM * 2 + 1 = 11 actions, but step()
        # only wires ids 0..8; ids 9 and 10 currently behave like the no-op.
        self.action_space = spaces.Discrete(self.CUT_NUM * 2 + 1)

        # The observation is every row cut followed by every column cut
        # (6 + 7 = 13 values with the defaults above).
        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(len(self.ORI_ROW_CUTS) + len(self.ORI_COL_CUTS),),
            dtype=np.float32,
        )

        # Mutable per-episode state (also re-initialised by reset()).
        self.row_cuts = self.ORI_ROW_CUTS[:]
        self.col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        self.adjust_step = 0
        self.best_path = None

        # Fleet / cost-model parameters.
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)
        self.H = params['H']
        self.W = params['W']
        self.center = (self.H / 2, self.W / 2)
        self.num_cars = params['num_cars']
        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']
        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

    def _observation(self):
        """Return the current cuts as a float32 vector (rows, then columns)."""
        # float32 matches the dtype declared by observation_space.
        return np.array(self.row_cuts + self.col_cuts, dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Restore the original cuts and clear all per-episode state.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed the env's RNG.
            options: Unused.

        Returns:
            The initial observation.
        """
        # NOTE(review): Gymnasium's Env.reset contract is ``(obs, info)``;
        # only the observation is returned here so existing callers keep
        # working. Confirm against the training loop before changing.
        super().reset(seed=seed)
        self.row_cuts = self.ORI_ROW_CUTS[:]
        self.col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        self.adjust_step = 0
        self.best_path = None
        return self._observation()

    def _cuts_strictly_increasing(self):
        """Return True when both cut lists are in strictly ascending order."""
        return all(
            lo < hi
            for cuts in (self.row_cuts, self.col_cuts)
            for lo, hi in zip(cuts, cuts[1:])
        )

    def step(self, action):
        """Apply one cut adjustment and score the resulting partition.

        Args:
            action: Action id. 0 is a no-op; 1..8 move one cut by +/-0.01 as
                listed in ``_ACTION_EFFECTS``.

        Returns:
            ``(state, reward, terminated, truncated, info)``. ``terminated``
            becomes True once ``MAX_ADJUST_STEP`` steps have been taken;
            ``truncated`` is always False. ``info`` holds copies of the
            current cuts plus ``best_path`` and ``best_time``.
        """
        # ``==`` (not a dict lookup) so numpy integer actions match as before.
        for action_id, (attr, idx, delta) in enumerate(self._ACTION_EFFECTS, start=1):
            if action == action_id:
                getattr(self, attr)[idx] += delta
                break

        # An out-of-order or infeasible partition is scored with a large time.
        penalty_time = self.BASE_LINE * 2
        if not self._cuts_strictly_increasing():
            best_time = penalty_time
        else:
            rectangles = self.if_valid_partition()
            if not rectangles:
                best_time = penalty_time
            elif self.adjust_step % 10 == 0 or self.best_path is None:
                # Re-plan the route every 10th step, and whenever no route
                # has been recorded yet.
                best_time, self.best_path = self.ga_solver(rectangles)
            else:
                # Otherwise re-evaluate the recorded route on the new layout.
                best_time = self.get_best_time(self.best_path, rectangles)

        reward = self.calc_reward(best_time)
        self.adjust_step += 1

        # Copy the cut lists: they are mutated in place by later steps, which
        # would otherwise silently rewrite every info dict a caller stored.
        info = {
            'row_cuts': list(self.row_cuts),
            'col_cuts': list(self.col_cuts),
            'best_path': self.best_path,
            'best_time': best_time,
        }
        terminated = self.adjust_step >= self.MAX_ADJUST_STEP
        return self._observation(), reward, terminated, False, info

    def if_valid_partition(self):
        """Build the per-rectangle cost records for the current cuts.

        Returns:
            A list with one dict per grid cell (row-major order) holding
            ``'center'``, ``'flight_time'`` and ``'bs_time'``; or an empty
            list as soon as any cell has a negative energy-limited ratio.
        """
        # Independent of the cell, so computed once outside the loops.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)

        rectangles = []
        for row_lo, row_hi in zip(self.row_cuts, self.row_cuts[1:]):
            for col_lo, col_hi in zip(self.col_cuts, self.col_cuts[1:]):
                # Cell area in absolute units.
                d = (col_hi - col_lo) * self.W * (row_hi - row_lo) * self.H

                rho_energy_limit = (
                    self.battery_energy_capacity
                    - self.flight_energy_factor * d
                    - self.trans_energy_factor * d
                ) / (self.comp_energy_factor * d - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return []

                rho = min(rho_time_limit, rho_energy_limit)
                rectangles.append({
                    'center': ((row_lo + row_hi) * self.H / 2,
                               (col_hi + col_lo) * self.W / 2),
                    'flight_time': self.flight_time_factor * d,
                    'bs_time': self.bs_time_factor * (1 - rho) * d,
                })
        return rectangles

    def check_adjustment_threshold(self, threshold=0.1):
        """Check how far the cuts have drifted from their original positions.

        Args:
            threshold (float): Maximum allowed total drift.

        Returns:
            bool: True if the sum of absolute differences between every
            current cut and its original position exceeds ``threshold``.
        """
        current = self.row_cuts + self.col_cuts
        original = self.ORI_ROW_CUTS + self.ORI_COL_CUTS
        delta = sum(abs(now - ori) for now, ori in zip(current, original))
        return delta > threshold

    def _make_ga(self, rectangles):
        """Create a ``GA`` instance for routing over ``rectangles``.

        The city array is: the area centre, every rectangle centre, then
        ``num_cars - 1`` further copies of the area centre. The indices of
        all centre copies are passed to ``GA`` as ``to_process_idx``.
        """
        depot = self.center
        extra_depots = self.num_cars - 1
        points = [depot]
        points.extend(rec['center'] for rec in rectangles)
        points.extend([depot] * extra_depots)
        cities = np.array(points)

        first_extra = len(rectangles) + 1
        center_idx = [0] + list(range(first_extra, first_extra + extra_depots))
        return GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
                  data=cities, to_process_idx=center_idx, rectangles=rectangles)

    def ga_solver(self, rectangles):
        """Run the GA on ``rectangles`` and return ``(best_time, best_path)``."""
        best_path, best_time = self._make_ga(rectangles).run()
        return best_time, best_path

    def get_best_time(self, best_path, rectangles):
        """Return ``GA.compute_pathlen(best_path)`` for the given rectangles."""
        return self._make_ga(rectangles).compute_pathlen(best_path)

    def calc_reward(self, best_time):
        """Map a route time to a reward in (-1, 1).

        The reward is positive when ``best_time`` is below ``BASE_LINE``,
        negative when it is above, and exactly 0 when they are equal.

        Args:
            best_time (float): Time of the current route.

        Returns:
            float: ``tanh((BASE_LINE - best_time) / 5000)``.
        """
        time_diff = self.BASE_LINE - best_time
        # 5000 is a tunable scaling factor that controls how quickly tanh
        # saturates.
        return np.tanh(time_diff / 5000)

    def render(self):
        """Print a short textual description of the environment state."""
        # ``phase`` and the attributes printed in the two phase branches are
        # never assigned by this class; they are only used if something else
        # (e.g. a subclass) sets them. Without the guard render() always
        # raised AttributeError.
        phase = getattr(self, 'phase', None)
        if phase == 1:
            print("Phase 1: Initialize maze environment.")
            print(f"Partition values so far: {self.partition_values}")
            print(f"Motorcade positon: {self.car_pos}")
        elif phase == 2:
            print("Phase 2: Play maze.")
            print(f'Motorcade trajectory: {self.car_traj}')
        else:
            print(f"Adjust step: {self.adjust_step}")
            print(f"Row cuts: {self.row_cuts}")
            print(f"Col cuts: {self.col_cuts}")