2025-03-29 21:28:39 +08:00
|
|
|
|
import gymnasium as gym
|
|
|
|
|
from gymnasium import spaces
|
|
|
|
|
import numpy as np
|
|
|
|
|
import yaml
|
|
|
|
|
import math
|
|
|
|
|
from mTSP_solver import mTSP
|
|
|
|
|
from GA.ga import GA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PartitionEnv(gym.Env):
    """Environment for tuning a rectangular partition of the task area.

    The area (``H`` x ``W``, read from ``<params>.yml``) is split by row cuts
    and column cuts, given as fractions in [0, 1] that include the 0 and 1
    boundaries. Each action moves one inner cut by ``ADJUST_STEP_SIZE`` or
    does nothing. After every move, the resulting rectangles are checked for
    feasibility and scored with a genetic-algorithm route planner. The reward
    compares that completion time against ``BASE_LINE``.

    Observation: ``row_cuts + col_cuts`` as a float32 vector.

    Actions (``Discrete(2 * CUT_NUM + 1)``):
        0                -> no-op
        2k + 1 / 2k + 2  -> move the k-th inner cut up / down. Inner cuts are
                            ordered row cuts first, then column cuts, so with
                            the default cuts: 1-6 move row cuts 1..3 and
                            7-8 move column cut 1.
    """

    def __init__(self, config=None):
        """Load scenario parameters and build the action/observation spaces.

        Args:
            config: Currently unused; kept for interface compatibility.
        """
        super().__init__()

        # ---- Hyper-parameters that may need manual tuning ----
        # Base name of the YAML scenario file. '.yml' is appended and the path
        # is resolved relative to the current working directory.
        self.params = 'params_50_50_3'
        # Initial cut positions as fractions of the height / width.
        self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
        self.ORI_COL_CUTS = [0, 0.5, 1]
        # Number of movable (inner) cuts: 3 row cuts + 1 column cut = 4.
        self.CUT_NUM = (len(self.ORI_ROW_CUTS) - 2) + (len(self.ORI_COL_CUTS) - 2)
        # Reference completion time that the reward is measured against.
        self.BASE_LINE = 9051.16
        # Episode length in steps.
        self.MAX_ADJUST_STEP = 50
        # Not used by step(); check_adjustment_threshold takes its own argument.
        self.ADJUST_THRESHOLD = 0.1
        # Distance a single action moves one cut.
        self.ADJUST_STEP_SIZE = 0.01
        # Re-run the GA every this many steps; in between, re-score the cached path.
        self.REPLAN_INTERVAL = 10
        # self.mTSP_STEPS = 10000

        # One "+step" and one "-step" action per inner cut, plus the no-op 0.
        self.action_space = spaces.Discrete(self.CUT_NUM * 2 + 1)
        # Observation: every cut position, boundaries included (5 + 3 = 8 values).
        self.observation_space = spaces.Box(
            low=0.0, high=1.0,
            shape=(len(self.ORI_ROW_CUTS) + len(self.ORI_COL_CUTS),),
            dtype=np.float32)
        self._action_table = self._build_action_table()

        self._reset_episode_state()

        # Fleet / scenario parameters.
        with open(self.params + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        self.H = params['H']
        self.W = params['W']
        self.center = (self.H / 2, self.W / 2)
        self.num_cars = params['num_cars']

        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

    # ------------------------------------------------------------------ helpers

    def _build_action_table(self):
        """Map each non-zero action id to (cut-list attribute, index, signed delta)."""
        inner_cuts = ([('row_cuts', i) for i in range(1, len(self.ORI_ROW_CUTS) - 1)]
                      + [('col_cuts', j) for j in range(1, len(self.ORI_COL_CUTS) - 1)])
        table = {}
        for k, (attr, idx) in enumerate(inner_cuts):
            table[2 * k + 1] = (attr, idx, self.ADJUST_STEP_SIZE)
            table[2 * k + 2] = (attr, idx, -self.ADJUST_STEP_SIZE)
        return table

    def _reset_episode_state(self):
        """Restore the initial cuts and clear per-episode bookkeeping."""
        self.row_cuts = self.ORI_ROW_CUTS[:]
        self.col_cuts = self.ORI_COL_CUTS[:]
        self.rectangles = []
        self.adjust_step = 0
        self.best_path = None

    def _get_obs(self):
        """Return the current cut positions with the dtype declared by observation_space."""
        return np.array(self.row_cuts + self.col_cuts, dtype=np.float32)

    @staticmethod
    def _strictly_increasing(values):
        """True if every element is strictly smaller than the next one."""
        return all(a < b for a, b in zip(values, values[1:]))

    def _evaluate_partition(self):
        """Return the completion time of the current cuts.

        Returns ``2 * BASE_LINE`` as a penalty when the cuts are out of order
        or when ``if_valid_partition`` rejects the partition.
        """
        penalty = self.BASE_LINE * 2
        if not (self._strictly_increasing(self.row_cuts)
                and self._strictly_increasing(self.col_cuts)):
            return penalty

        rectangles = self.if_valid_partition()
        if not rectangles:
            return penalty

        # Plan a fresh route every REPLAN_INTERVAL steps, at step 1, and
        # whenever no path is cached yet. Otherwise re-score the cached path
        # on the new rectangles, which is much cheaper than a full GA run.
        if (self.adjust_step % self.REPLAN_INTERVAL == 0
                or self.adjust_step == 1 or self.best_path is None):
            best_time, self.best_path = self.ga_solver(rectangles)
            return best_time
        return self.get_best_time(self.best_path, rectangles)

    def _make_ga(self, rectangles):
        """Build a GA instance over the area centre plus the rectangle centres.

        City row 0 is the area centre, followed by one row per rectangle
        centre, followed by ``num_cars - 1`` extra copies of the centre.
        ``to_process_idx`` lists every row that holds a copy of the centre.
        """
        n_rect = len(rectangles)
        n_extra = max(self.num_cars - 1, 0)
        cities = np.array([self.center]
                          + [rec['center'] for rec in rectangles]
                          + [self.center] * n_extra)
        center_idx = [0] + list(range(n_rect + 1, n_rect + 1 + n_extra))
        return GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20,
                  data=cities, to_process_idx=center_idx, rectangles=rectangles)

    # --------------------------------------------------------------- gym API

    def reset(self, seed=None, options=None):
        """Restore the initial cuts and clear episode state.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Unused.

        Returns:
            The initial observation (all cut positions, float32).

        NOTE(review): gymnasium's API expects ``reset`` to return
        ``(obs, info)``. This returns only ``obs`` to stay compatible with
        existing callers; confirm what the training loop expects.
        """
        super().reset(seed=seed)
        self._reset_episode_state()
        return self._get_obs()

    def step(self, action):
        """Apply one cut adjustment and score the resulting partition.

        Args:
            action: Integer action id; see the class docstring. Ids without a
                mapping (0 or out of range) leave the cuts unchanged.

        Returns:
            ``(state, reward, terminated, truncated, info)``. ``info`` holds
            snapshots of the cuts plus the cached best path and the time
            used for the reward.
        """
        move = self._action_table.get(int(action))
        if move is not None:
            attr, idx, delta = move
            cuts = getattr(self, attr)
            # Round to strip floating-point error accumulated by repeated
            # +/-0.01 moves, so cuts that meet compare exactly equal.
            cuts[idx] = round(cuts[idx] + delta, 10)

        # An invalid move is not undone: the agent gets the penalty and must
        # move the cut back itself.
        best_time = self._evaluate_partition()

        # The reward uses the step count *before* it is incremented.
        reward = self.calc_reward(best_time)
        self.adjust_step += 1

        state = self._get_obs()
        # Copy the cut lists so later steps do not mutate info dicts already returned.
        info = {'row_cuts': list(self.row_cuts), 'col_cuts': list(self.col_cuts),
                'best_path': self.best_path, 'best_time': best_time}

        # NOTE(review): reaching the step limit is a time limit, which gymnasium
        # would report as truncated=True. It is kept as terminated=True for
        # compatibility with the existing training loop.
        terminated = self.adjust_step >= self.MAX_ADJUST_STEP
        return state, reward, terminated, False, info

    # -------------------------------------------------------------- scoring

    def if_valid_partition(self):
        """Build per-rectangle data for the current cuts, or ``[]`` if infeasible.

        For each grid cell, ``d`` is its area in the units of ``H`` and ``W``.
        ``rho`` is the smaller of a time-based limit and an energy-based limit.
        It enters ``bs_time = bs_time_factor * (1 - rho) * d``; presumably it is
        the share of work handled on board (confirm against the model).
        If any cell has a negative energy limit, the whole partition is
        rejected.

        Returns:
            A list of dicts with keys ``'center'`` (row-coordinate,
            col-coordinate), ``'flight_time'`` and ``'bs_time'``, ordered
            row-major; or ``[]`` when the partition is infeasible.
        """
        # The time-based limit does not depend on the cell, so compute it once.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)

        rectangles = []
        for r0, r1 in zip(self.row_cuts, self.row_cuts[1:]):
            for c0, c1 in zip(self.col_cuts, self.col_cuts[1:]):
                d = (c1 - c0) * self.W * (r1 - r0) * self.H
                rho_energy_limit = (self.battery_energy_capacity
                                    - self.flight_energy_factor * d
                                    - self.trans_energy_factor * d) / \
                    (self.comp_energy_factor * d - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return []
                rho = min(rho_time_limit, rho_energy_limit)

                rectangles.append({
                    'center': ((r0 + r1) * self.H / 2, (c0 + c1) * self.W / 2),
                    'flight_time': self.flight_time_factor * d,
                    'bs_time': self.bs_time_factor * (1 - rho) * d,
                })
        return rectangles

    def check_adjustment_threshold(self, threshold=0.1):
        """Report whether the cuts have drifted too far from their initial positions.

        The drift is the *total* absolute change summed over every row and
        column cut (an L1 distance), not the change of any single cut.

        Args:
            threshold (float): Maximum allowed total drift.

        Returns:
            bool: True if the total drift exceeds ``threshold``.
        """
        current = self.row_cuts + self.col_cuts
        original = self.ORI_ROW_CUTS + self.ORI_COL_CUTS
        delta = sum(abs(cur - ori) for cur, ori in zip(current, original))
        return delta > threshold

    # Alternative Q-learning mTSP solver, currently disabled.
    # NOTE: it reads `rectangles`, which would need to be a parameter.
    # def q_learning_solver(self):
    #     # Solve the multi-vehicle TSP with Q-learning.
    #     # cities: [[x1, x2, x3...], [y1, y2, y3...]] city coordinates
    #     rec_center_lt = [rec_info['center']
    #                      for rec_info in rectangles]
    #     cities = np.column_stack(rec_center_lt)
    #     cities = np.column_stack((self.center, cities))
    #
    #     center_idx = []
    #     for i in range(self.num_cars - 1):
    #         cities = np.column_stack((cities, self.center))
    #         center_idx.append(cities.shape[1] - 1)
    #
    #     tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars,
    #                center_idx=center_idx, rectangles=rectangles)
    #
    #     best_time, best_path = tsp.train(self.mTSP_STEPS)

    def ga_solver(self, rectangles):
        """Plan routes over the rectangle centres with the GA.

        Returns:
            ``(best_time, best_path)``, swapped from the ``(path, time)``
            order returned by ``GA.run``.
        """
        best_path, best_time = self._make_ga(rectangles).run()
        return best_time, best_path

    def get_best_time(self, best_path, rectangles):
        """Score an existing path on the given rectangles without re-running the GA."""
        return self._make_ga(rectangles).compute_pathlen(best_path)

    def calc_reward(self, best_time):
        """Convert a completion time into a reward.

        1. A time below ``BASE_LINE`` gives a positive reward; a time above
           it gives a negative reward.
        2. The difference is squashed by ``tanh(diff / 200)`` into (-1, 1)
           and is 0 when the time equals ``BASE_LINE``.
        3. The result is scaled by ``1 / (1 + exp(-adjust_step / 10))``. This
           weight is 0.5 at step 0 and grows towards 1, so later steps count
           more.

        Args:
            best_time (float): Completion time of the current partition.

        Returns:
            float: The reward.
        """
        time_diff = self.BASE_LINE - best_time

        # tanh is ~1 once its argument reaches ~2, i.e. |time_diff| of about
        # 400 with this scale. 200 is a tunable scale factor.
        normalized_diff = np.tanh(time_diff / 200)

        # Step weight (grows with the step index).
        step_weight = 1 / (1 + np.exp(-self.adjust_step / 10))

        return normalized_diff * step_weight

    def render(self):
        """Print the current partition state to stdout."""
        print(f"Adjust step: {self.adjust_step}/{self.MAX_ADJUST_STEP}")
        print(f"Row cuts: {self.row_cuts}")
        print(f"Col cuts: {self.col_cuts}")
        print(f"Best path: {self.best_path}")