import json
import math
import os
import random
from dataclasses import dataclass
from typing import Tuple, List, Dict, Any, Optional

import numpy as np
import yaml


@dataclass
class Config:
    """Configuration object that gathers every hyper-parameter in one place.

    ``VALUES`` and ``ACTION_DELTA`` are derived from ``STEP`` in
    ``__post_init__`` when they are left as ``None``; explicitly supplied
    lists are kept as given. ``ACTIONS`` is always rebuilt from ``NUM_VARS``
    and ``ACTION_DELTA`` and is a plain instance attribute (not a dataclass
    field).
    """

    # Learning parameters
    ALPHA: float = 0.1  # learning rate
    GAMMA: float = 0.9  # discount factor
    EPSILON: float = 0.2  # exploration rate
    NUM_EPISODES: int = 100  # number of training episodes

    # State-space parameters
    STEP: float = 0.01  # step size of a single state change
    VALUES: Optional[List[float]] = None  # discrete values 0.00..1.00; derived when None

    # Action-space parameters
    ACTION_DELTA: Optional[List[float]] = None  # per-action change; [+STEP, -STEP] when None

    # Environment parameters
    MIN_IMPROVEMENT: float = 0.001  # minimum improvement threshold
    MAX_NO_IMPROVEMENT: int = 10  # maximum allowed consecutive non-improving steps
    TARGET_THRESHOLD: float = 10000  # acceptable threshold for the objective value

    # Number of state variables an action may adjust. Declared last so that
    # existing positional constructions keep their meaning.
    NUM_VARS: int = 4

    def __post_init__(self):
        """Fill in the parameters that depend on other fields."""
        if self.VALUES is None:
            self.VALUES = [round(i * self.STEP, 2) for i in range(101)]
        if self.ACTION_DELTA is None:
            self.ACTION_DELTA = [self.STEP, -self.STEP]
        # One action per (variable index, delta) pair, ordered by variable index.
        self.ACTIONS = [(index, delta)
                        for index in range(self.NUM_VARS)
                        for delta in self.ACTION_DELTA]


class FunctionEnv:
    """Environment class: defines state transitions and rewards.

    The state is a flat sequence of cut ratios; the first ``cut_index``
    entries are row cuts and the remaining entries are column cuts.
    """

    # Reward returned when an action leads to a state that violates the constraints.
    INVALID_STATE_REWARD = -10000.0
    # Rewards are computed as ``REWARD_OFFSET - objective``.
    REWARD_OFFSET = 12000
    # Objective reported as soon as one cell has a negative energy-limited ratio.
    INFEASIBLE_OBJECTIVE = 1000000

    def __init__(self, params_file: str, initial_state: Tuple[float, ...], cut_index: int,
                 car_paths: List[List[int]], config: "Config"):
        """
        Load the problem parameters and initialise the episode bookkeeping.

        Args:
            params_file: Path of the YAML parameter file without the ``.yml`` suffix.
            initial_state: Initial cut ratios (row cuts followed by column cuts).
            cut_index: Number of leading state entries that are row cuts.
            car_paths: For each car, the list of rectangle indices it visits.
            config: Hyper-parameter object; this class reads ``TARGET_THRESHOLD``,
                ``MIN_IMPROVEMENT`` and ``MAX_NO_IMPROVEMENT`` from it.
        """
        self.params_file = params_file
        with open(self.params_file + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        self.H = params['H']
        self.W = params['W']
        self.num_cars = params['num_cars']

        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

        self.state = initial_state
        self.cut_index = cut_index
        self.config = config
        self.car_paths = car_paths
        self.best_value = float('inf')
        self.no_improvement_count = 0
        self.last_state = None

    def step(self, action: Tuple[int, float]) -> Tuple[List[float], float, bool]:
        """
        Execute one action.

        Args:
            action: A ``(variable index, delta)`` tuple.

        Returns:
            ``(next state, reward, done)``. When the action would violate the
            constraints the current state is returned unchanged together with
            ``INVALID_STATE_REWARD`` and ``done=True``.
        """
        var_index, delta = action
        new_state = list(self.state)
        # Rounding to two decimals keeps the state on the 0.01 grid.
        new_state[var_index] = round(new_state[var_index] + delta, 2)

        row_cuts_state = new_state[:self.cut_index]
        col_cuts_state = new_state[self.cut_index:]

        # Check the constraints on both halves of the state.
        if not (self._is_valid_state(row_cuts_state)
                and self._is_valid_state(col_cuts_state)):
            return self.state, self.INVALID_STATE_REWARD, True

        current_value = self.calculate_objective(row_cuts_state, col_cuts_state)
        reward = self.REWARD_OFFSET - current_value

        # Check the termination conditions; the stored state is not advanced
        # when the episode ends here.
        if self._should_terminate(new_state, current_value):
            return new_state, reward, True

        self._update_state(new_state, current_value)
        return new_state, reward, False

    def _is_valid_state(self, state: List[float]) -> bool:
        """
        Check whether a list of cuts satisfies the constraints.

        The list must be non-empty, every element must lie strictly inside
        (0, 1), and the elements must be strictly increasing.

        Args:
            state: The list of cut ratios to check.

        Returns:
            bool: Whether the constraints are satisfied.
        """
        if not state:
            return False
        if not all(0 < x < 1 for x in state):
            return False
        return all(left < right for left, right in zip(state, state[1:]))

    def _should_terminate(self, state: List[float], value: float) -> bool:
        """
        Decide whether the episode should end and update the improvement counter.

        Args:
            state: Candidate next state.
            value: Objective value of ``state``.

        Returns:
            bool: True when ``value`` is below ``TARGET_THRESHOLD`` or when the
            no-improvement counter has reached ``MAX_NO_IMPROVEMENT``.
        """
        if value < self.config.TARGET_THRESHOLD:
            return True

        # NOTE(review): the counter is driven by two independent checks (state
        # movement and objective improvement). A state that differs from
        # ``last_state`` resets the counter before the objective check runs, so
        # the counter rarely accumulates. Logic kept as written; confirm intent.
        if self.last_state is not None:
            state_diff = sum(abs(a - b) for a, b in zip(state, self.last_state))
            if state_diff < self.config.MIN_IMPROVEMENT:
                self.no_improvement_count += 1
            else:
                self.no_improvement_count = 0

        if value < self.best_value:
            self.best_value = value
            self.no_improvement_count = 0
        else:
            self.no_improvement_count += 1

        return self.no_improvement_count >= self.config.MAX_NO_IMPROVEMENT

    def _update_state(self, state: List[float], value: float) -> None:
        """Remember the current state as ``last_state`` and advance to ``state``.

        ``value`` is accepted for interface symmetry and is not used.
        """
        self.last_state = self.state
        self.state = state

    def calculate_objective(self, row_cuts, col_cuts):
        """
        Compute the objective value T for a set of cut ratios.

        The unit square is split by the cuts into a grid of rectangles that are
        numbered row by row. Each car's time is the larger of (travel time plus
        summed flight time) and the summed ``bs_time`` over its path; the
        result is the largest time over all cars.

        :param row_cuts: Row cut ratios (without the 0 and 1 boundaries).
        :param col_cuts: Column cut ratios (without the 0 and 1 boundaries).
        :return: Objective value T, or ``INFEASIBLE_OBJECTIVE`` as soon as one
            rectangle has a negative energy-limited ratio.
        """
        row_cuts = [0, *row_cuts, 1]
        col_cuts = [0, *col_cuts, 1]

        # The time-limited ratio does not depend on the rectangle, so compute it once.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
            (self.comp_time_factor - self.trans_time_factor)

        rectangles = []
        for i in range(len(row_cuts) - 1):
            for j in range(len(col_cuts) - 1):
                d = (col_cuts[j + 1] - col_cuts[j]) * self.W * \
                    (row_cuts[i + 1] - row_cuts[i]) * self.H
                rho_energy_limit = (self.battery_energy_capacity
                                    - self.flight_energy_factor * d
                                    - self.trans_energy_factor * d) / \
                    (self.comp_energy_factor * d - self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return self.INFEASIBLE_OBJECTIVE
                rho = min(rho_time_limit, rho_energy_limit)

                rectangles.append({
                    'flight_time': self.flight_time_factor * d,
                    'bs_time': self.bs_time_factor * (1 - rho) * d,
                    'center': ((row_cuts[i] + row_cuts[i + 1]) / 2.0 * self.H,
                               (col_cuts[j] + col_cuts[j + 1]) / 2.0 * self.W),
                })

        # Every car path starts and ends at this fixed point.
        depot = [self.H / 2, self.W / 2]

        mortorcade_time_lt = []
        for idx in range(self.num_cars):
            car_path = self.car_paths[idx]

            flight_time = sum(rectangles[point]['flight_time'] for point in car_path)
            bs_time = sum(rectangles[point]['bs_time'] for point in car_path)

            car_time = 0
            # An empty path has no legs and no depot trips, so it contributes zero.
            if car_path:
                for first_point, second_point in zip(car_path, car_path[1:]):
                    car_time += math.dist(
                        rectangles[first_point]['center'],
                        rectangles[second_point]['center']) * self.car_time_factor
                car_time += math.dist(rectangles[car_path[0]]['center'],
                                      depot) * self.car_time_factor
                car_time += math.dist(rectangles[car_path[-1]]['center'],
                                      depot) * self.car_time_factor

            mortorcade_time_lt.append(max(car_time + flight_time, bs_time))

        return max(mortorcade_time_lt)

    def reset(self, state: Tuple[float, ...]) -> Tuple[float, ...]:
        """Reset the environment to ``state`` and clear the episode bookkeeping.

        The best value, the no-improvement counter and the remembered previous
        state are restored to their initial values so that termination
        decisions in a new episode do not depend on the previous one.

        Returns:
            The state the environment was reset to.
        """
        self.state = state
        self.best_value = float('inf')
        self.no_improvement_count = 0
        self.last_state = None
        return self.state


class QLearning:
    """Tabular Q-learning store with an epsilon-greedy action selector.

    ``Q_table`` maps a state tuple to a dict of ``action -> Q value``. A row is
    created lazily, with every action from ``config.ACTIONS`` set to 0.0, the
    first time a state is looked up.
    """

    def __init__(self, config: "Config"):
        """Store the configuration and start with an empty Q-table."""
        self.config = config
        self.Q_table: Dict[Tuple[float, ...],
                           Dict[Tuple[int, float], float]] = {}

    def _action_values(self, state: Tuple[float, ...]) -> Dict[Tuple[int, float], float]:
        """Return the Q-value row of ``state``, creating a zero row on first access."""
        row = self.Q_table.get(state)
        if row is None:
            row = {a: 0.0 for a in self.config.ACTIONS}
            self.Q_table[state] = row
        return row

    def get_Q(self, state: Tuple[float, ...], action: Tuple[int, float]) -> float:
        """Return the Q value of ``(state, action)``.

        Raises:
            KeyError: If ``action`` is not one of the actions in the state's row.
        """
        return self._action_values(state)[action]

    def set_Q(self, state: Tuple[float, ...], action: Tuple[int, float], value: float) -> None:
        """Set the Q value of ``(state, action)`` to ``value``."""
        self._action_values(state)[action] = value

    def choose_action(self, state: Tuple[float, ...], epsilon: float) -> Tuple[int, float]:
        """Choose an action with the epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned
        (the Q-table is not touched in that case); otherwise the action with
        the highest Q value is returned, ties going to the action that appears
        first in the row.
        """
        if random.random() < epsilon:
            return random.choice(self.config.ACTIONS)
        row = self._action_values(state)
        return max(row, key=row.get)


def load_initial_solution(file_path: str) -> Tuple[List[float], List[float], List[List[int]]]:
    """
    Load an initial solution from a JSON file.

    Args:
        file_path: Path of the JSON file.

    Returns:
        Tuple[List[float], List[float], List[List[int]]]: the values stored
        under ``row_boundaries``, ``col_boundaries`` and ``car_paths``, in
        that order.

    Raises:
        KeyError: If any of the three required keys is missing; the message
            names the file and the missing keys.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    required_keys = ('row_boundaries', 'col_boundaries', 'car_paths')
    missing = [key for key in required_keys if key not in data]
    if missing:
        raise KeyError(
            f"{file_path} is missing required key(s): {', '.join(missing)}")

    return data['row_boundaries'], data['col_boundaries'], data['car_paths']


def main(solution_path: Optional[str] = None, params_file: str = "params2", seed: int = 42):
    """Train the Q-learning agent and print the best state it visited.

    Args:
        solution_path: JSON file holding the initial solution. Defaults to
            ``solutions/trav_ga_params2_parallel.json`` built with the
            platform's path separator.
        params_file: Parameter file name without the ``.yml`` suffix.
        seed: Seed passed to ``random.seed`` before training.
    """
    # Initialise the configuration
    config = Config()
    random.seed(seed)

    # Load the initial solution
    if solution_path is None:
        # A literal backslash path only resolves on Windows.
        solution_path = os.path.join("solutions", "trav_ga_params2_parallel.json")

    initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution(
        solution_path)

    # Drop the fixed outer boundaries; only the interior cuts are adjustable.
    initial_state = initial_row_cuts[1:-1] + initial_col_cuts[1:-1]
    cut_index = len(initial_row_cuts) - 2

    # Give every state variable its own pair of actions, so the action set
    # matches the real state length rather than a fixed count.
    config.ACTIONS = [(index, delta)
                      for index in range(len(initial_state))
                      for delta in config.ACTION_DELTA]

    # Initialise the environment and Q-learning
    env = FunctionEnv(params_file, initial_state, cut_index, car_paths, config)
    q_learning = QLearning(config)

    # Training loop
    for episode in range(config.NUM_EPISODES):
        print(f"Episode {episode + 1} of {config.NUM_EPISODES}")
        state = tuple(env.reset(initial_state))
        done = False

        while not done:
            action = q_learning.choose_action(state, config.EPSILON)
            next_state, reward, done = env.step(action)
            next_state = tuple(next_state)

            # Q-learning update. The next state is always registered in the
            # table because the final scan below iterates over the table.
            next_row = q_learning.Q_table.setdefault(
                next_state, {a: 0.0 for a in config.ACTIONS})
            # A terminal transition has no future return to bootstrap from.
            if done:
                target = reward
            else:
                target = reward + config.GAMMA * max(next_row.values())
            current_Q = q_learning.get_Q(state, action)
            new_Q = current_Q + config.ALPHA * (target - current_Q)
            q_learning.set_Q(state, action, new_Q)
            state = next_state

        # Decay the exploration rate
        config.EPSILON = max(0.01, config.EPSILON * 0.999)

    # Report the best state among all states stored in the Q-table
    best_state = None
    best_value = float('inf')
    for state in q_learning.Q_table:
        state = list(state)
        state_value = env.calculate_objective(
            state[:cut_index], state[cut_index:])
        if state_value < best_value:
            best_value = state_value
            best_state = state

    print("找到的最优状态:", best_state, "对应函数值:", best_value)


if __name__ == "__main__":
    main()