import random
import json
import math
import yaml
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass


@dataclass
class Config:
    """Configuration class that centralizes all hyperparameters."""
    # Learning parameters
    ALPHA: float = 0.1  # learning rate
    GAMMA: float = 0.9  # discount factor
    EPSILON: float = 0.2  # exploration rate
    NUM_EPISODES: int = 100  # number of training episodes

    # State-space parameters
    STEP: float = 0.01  # step size for state changes
    VALUES: Optional[List[float]] = None  # discrete values from 0.00 to 1.00

    # Action-space parameters
    ACTION_DELTA: Optional[List[float]] = None  # per-action increments

    # Environment parameters
    MIN_IMPROVEMENT: float = 0.001  # minimum improvement threshold
    MAX_NO_IMPROVEMENT: int = 10  # max consecutive steps without improvement
    TARGET_THRESHOLD: float = 10000.0  # acceptable objective-value threshold

    def __post_init__(self):
        """Initialize derived parameters."""
        self.VALUES = [round(i * self.STEP, 2) for i in range(101)]
        self.ACTION_DELTA = [self.STEP, -self.STEP]
        # TODO: the number of state variables (4) is hard-coded here and
        # should be derived from the actual state length.
        self.ACTIONS = [(i, delta) for i in range(4)
                        for delta in self.ACTION_DELTA]
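        # With the defaults above, ACTIONS enumerates every (variable index,
        # increment) pair, i.e. [(0, 0.01), (0, -0.01), (1, 0.01), (1, -0.01),
        # (2, 0.01), (2, -0.01), (3, 0.01), (3, -0.01)].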


class FunctionEnv:
    """Environment: defines state transitions and rewards."""

    def __init__(self, params_file: str, initial_state: List[float],
                 cut_index: int, car_paths: List[List[int]], config: Config):
        self.params_file = params_file
        with open(self.params_file + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        # Map geometry and fleet size
        self.H = params['H']
        self.W = params['W']
        self.num_cars = params['num_cars']

        # Time coefficients (per unit area or distance)
        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        # Energy coefficients and battery budget
        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

        self.state = initial_state
        self.cut_index = cut_index
        self.config = config
        self.car_paths = car_paths
        self.best_value = float('inf')
        self.no_improvement_count = 0
        self.last_state = None

    def step(self, action: Tuple[int, float]) -> Tuple[List[float], float, bool]:
        """
        Execute one action.

        Args:
            action: (variable index, delta) tuple

        Returns:
            Tuple[List[float], float, bool]: (next state, reward, done)
        """
        var_index, delta = action
        new_state = list(self.state)
        new_state[var_index] = round(new_state[var_index] + delta, 2)
        row_cuts_state = new_state[:self.cut_index]
        col_cuts_state = new_state[self.cut_index:]

        # Check constraints; an invalid move ends the episode with a large penalty
        if not self._is_valid_state(row_cuts_state) or not self._is_valid_state(col_cuts_state):
            return self.state, -10000.0, True

        current_value = self.calculate_objective(
            row_cuts_state, col_cuts_state)

        # Reward is a fixed offset minus the objective, so lower objective
        # values yield higher rewards.
        if self._should_terminate(new_state, current_value):
            return new_state, 12000 - current_value, True

        self._update_state(new_state, current_value)
        return new_state, 12000 - current_value, False

    def _is_valid_state(self, state: List[float]) -> bool:
        """
        Check whether a state satisfies the constraints: all elements must
        be strictly increasing and lie in the open interval (0, 1).

        Args:
            state: state list to check

        Returns:
            bool: True if the constraints are satisfied
        """
        # Reject empty lists
        if not state:
            return False

        # Every element must lie in (0, 1)
        if not all(0 < x < 1 for x in state):
            return False

        # Elements must be strictly increasing
        for i in range(len(state) - 1):
            if state[i] >= state[i + 1]:
                return False

        return True
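
    # Examples: _is_valid_state([0.2, 0.5, 0.9]) -> True;
    # _is_valid_state([0.5, 0.5]) -> False (not strictly increasing);
    # _is_valid_state([]) -> False.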

    def _should_terminate(self, state: List[float], value: float) -> bool:
        """Check whether the episode should terminate."""
        if value < self.config.TARGET_THRESHOLD:
            return True

        # Count steps whose state barely moved; note that a step that neither
        # moves the state nor improves the best value increments the counter
        # twice.
        if self.last_state is not None:
            state_diff = sum(abs(a - b)
                             for a, b in zip(state, self.last_state))
            if state_diff < self.config.MIN_IMPROVEMENT:
                self.no_improvement_count += 1
            else:
                self.no_improvement_count = 0

        if value < self.best_value:
            self.best_value = value
            self.no_improvement_count = 0
        else:
            self.no_improvement_count += 1

        return self.no_improvement_count >= self.config.MAX_NO_IMPROVEMENT

    def _update_state(self, state: List[float], value: float) -> None:
        """Advance to the new state, remembering the previous one."""
        self.last_state = self.state
        self.state = state

    def calculate_objective(self, row_cuts, col_cuts):
        """
        Compute the objective value T for the given cut ratios (placeholder implementation).

        :param row_cuts: row cut ratios (interior boundaries)
        :param col_cuts: column cut ratios (interior boundaries)
        :return: objective value T
        """
        row_cuts = [0] + row_cuts + [1]
        col_cuts = [0] + col_cuts + [1]
        rectangles = []
        for i in range(len(row_cuts) - 1):
            for j in range(len(col_cuts) - 1):
                # Area of rectangle (i, j)
                d = (col_cuts[j+1] - col_cuts[j]) * self.W * \
                    (row_cuts[i+1] - row_cuts[i]) * self.H
                rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
                    (self.comp_time_factor - self.trans_time_factor)
                rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
                    (self.comp_energy_factor * d -
                     self.trans_energy_factor * d)
                # A negative limit means the battery budget cannot cover this
                # rectangle at all; return a large penalty value.
                if rho_energy_limit < 0:
                    return 1000000
                rho = min(rho_time_limit, rho_energy_limit)
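                # rho reads as the fraction of work handled on board, capped
                # by both a time-balance limit and the battery budget (an
                # interpretation inferred from the factor names; the original
                # does not document it).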

                flight_time = self.flight_time_factor * d
                bs_time = self.bs_time_factor * (1 - rho) * d

                rectangles.append({
                    'flight_time': flight_time,
                    'bs_time': bs_time,
                    'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * self.H,
                               (col_cuts[j] + col_cuts[j+1]) / 2.0 * self.W)
                })

        # Completion time per car: driving between rectangle centers plus the
        # legs to and from the map center, in parallel with base-station time.
        motorcade_times = []
        for idx in range(self.num_cars):
            car_path = self.car_paths[idx]

            flight_time = sum(rectangles[point]['flight_time']
                              for point in car_path)
            bs_time = sum(rectangles[point]['bs_time'] for point in car_path)

            car_time = 0
            for i in range(len(car_path) - 1):
                first_point = car_path[i]
                second_point = car_path[i + 1]
                car_time += math.dist(
                    rectangles[first_point]['center'], rectangles[second_point]['center']) * self.car_time_factor
            car_time += math.dist(rectangles[car_path[0]]['center'],
                                  [self.H / 2, self.W / 2]) * self.car_time_factor
            car_time += math.dist(rectangles[car_path[-1]]['center'],
                                  [self.H / 2, self.W / 2]) * self.car_time_factor
            motorcade_times.append(max(car_time + flight_time, bs_time))

        # The objective is the makespan: the slowest car determines T.
        return max(motorcade_times)

    def reset(self, state: List[float]) -> List[float]:
        """Reset the environment state and the per-episode improvement trackers."""
        self.state = state
        # Without clearing these, a later episode would inherit the previous
        # episode's no-improvement streak and terminate immediately.
        self.no_improvement_count = 0
        self.last_state = None
        return self.state


class QLearning:
    """Tabular Q-learning implementation."""

    def __init__(self, config: Config):
        self.config = config
        self.Q_table: Dict[Tuple[float, ...],
                           Dict[Tuple[int, float], float]] = {}

    def get_Q(self, state: Tuple[float, ...], action: Tuple[int, float]) -> float:
        """Return the Q-value for (state, action), initializing unseen states."""
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        return self.Q_table[state][action]

    def set_Q(self, state: Tuple[float, ...], action: Tuple[int, float], value: float) -> None:
        """Set the Q-value for (state, action), initializing unseen states."""
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        self.Q_table[state][action] = value

    def choose_action(self, state: Tuple[float, ...], epsilon: float) -> Tuple[int, float]:
        """Choose an action with an epsilon-greedy policy."""
        if random.random() < epsilon:
            return random.choice(self.config.ACTIONS)
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        return max(self.Q_table[state].items(), key=lambda x: x[1])[0]
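
    # Illustration: with Config defaults, choose_action(state, 0.2) returns a
    # uniformly random (var_index, ±STEP) pair 20% of the time and otherwise
    # the highest-valued known action; unseen states start at 0.0 for every
    # action, so ties resolve to the first entry of Config.ACTIONS.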


def load_initial_solution(file_path: str) -> Tuple[List[float], List[float], List[List[int]]]:
    """
    Load the initial solution from a JSON file.

    Args:
        file_path: path to the JSON file

    Returns:
        Tuple[List[float], List[float], List[List[int]]]: (row cut ratios, column cut ratios, car paths)
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data['row_boundaries'], data['col_boundaries'], data['car_paths']
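
# Expected JSON layout (values illustrative, not from the original file):
# {
#     "row_boundaries": [0.0, 0.4, 0.7, 1.0],
#     "col_boundaries": [0.0, 0.5, 1.0],
#     "car_paths": [[0, 1], [2, 3], [4, 5]]
# }
# Boundaries include the outer 0.0 and 1.0; main() strips them to build the
# interior-cut state, and each car path lists rectangle indices in visit order.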


def main():
    """Main entry point: train Q-learning on the cut-optimization environment."""
    # Initialize configuration
    config = Config()
    random.seed(42)

    # Load the initial solution
    solution_path = r"solutions\trav_ga_params2_parallel.json"
    params_file = r"params2"

    initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution(
        solution_path)

    # Boundaries include the outer 0 and 1; strip them so the state holds
    # only the interior cut positions.
    initial_state = initial_row_cuts[1:-1] + initial_col_cuts[1:-1]
    cut_index = len(initial_row_cuts) - 2

    # Initialize the environment and the Q-learning agent
    env = FunctionEnv(params_file, initial_state, cut_index, car_paths, config)
    q_learning = QLearning(config)

    # Training loop
    for episode in range(config.NUM_EPISODES):
        print(f"Episode {episode + 1} of {config.NUM_EPISODES}")
        state = env.reset(initial_state)
        done = False

        while not done:
            action = q_learning.choose_action(tuple(state), config.EPSILON)
            next_state, reward, done = env.step(action)
            next_state = tuple(next_state)

            # Q-learning update:
            # Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
            if next_state not in q_learning.Q_table:
                q_learning.Q_table[next_state] = {
                    a: 0.0 for a in config.ACTIONS}
            max_next_Q = max(q_learning.Q_table[next_state].values())
            current_Q = q_learning.get_Q(tuple(state), action)
            new_Q = current_Q + config.ALPHA * \
                (reward + config.GAMMA * max_next_Q - current_Q)
            q_learning.set_Q(tuple(state), action, new_Q)
            state = next_state

        # Decay the exploration rate
        config.EPSILON = max(0.01, config.EPSILON * 0.999)

    # Report the best state among all visited states
    best_state = None
    best_value = float('inf')
    for state_key in q_learning.Q_table:
        state_list = list(state_key)
        state_value = env.calculate_objective(
            state_list[:cut_index], state_list[cut_index:])
        if state_value < best_value:
            best_value = state_value
            best_state = state_list
    print("Best state found:", best_state, "objective value:", best_value)


if __name__ == "__main__":
    main()