# HPCC2025/Q_learning/q_table.py
import random
import numpy as np
import json
import math
import yaml
from typing import Tuple, List, Dict, Any, Optional
from dataclasses import dataclass

@dataclass
class Config:
    """Configuration class: central place for all hyperparameters."""
    # Learning parameters
    ALPHA: float = 0.1        # learning rate
    GAMMA: float = 0.9        # discount factor
    EPSILON: float = 0.2      # exploration rate
    NUM_EPISODES: int = 100   # number of training episodes

    # State-space parameters
    STEP: float = 0.01          # step size for state changes
    VALUES: List[float] = None  # list of discrete values from 0.00 to 1.00

    # Action-space parameters
    ACTION_DELTA: List[float] = None  # amount of change per action

    # Environment parameters
    MIN_IMPROVEMENT: float = 0.001   # minimum improvement threshold
    MAX_NO_IMPROVEMENT: int = 10     # maximum allowed consecutive non-improving steps
    TARGET_THRESHOLD: float = 10000  # acceptable threshold for the objective value

    def __post_init__(self):
        """Initialize derived parameters."""
        self.VALUES = [round(i * self.STEP, 2) for i in range(101)]
        self.ACTION_DELTA = [self.STEP, -self.STEP]
        # TODO: the number of adjustable variables (4) is hard-coded here;
        # it should match the length of the state vector.
        self.ACTIONS = [(i, delta) for i in range(4)
                        for delta in self.ACTION_DELTA]
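
# Illustration (editor-added, not in the original file): with STEP = 0.01 the Config
# above yields
#   VALUES  = [0.00, 0.01, ..., 1.00]
#   ACTIONS = [(0, 0.01), (0, -0.01), (1, 0.01), (1, -0.01),
#              (2, 0.01), (2, -0.01), (3, 0.01), (3, -0.01)]
# i.e. every action nudges exactly one of the four adjustable variables up or down
# by one discretization step.
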
class FunctionEnv:
    """Environment class: defines state transitions and rewards."""
def __init__(self, params_file: str, initial_state: Tuple[float, float, float], cut_index: int, car_paths: List[List[int]], config: Config):
self.params_file = params_file
with open(self.params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file)
self.H = params['H']
self.W = params['W']
self.num_cars = params['num_cars']
self.flight_time_factor = params['flight_time_factor']
self.comp_time_factor = params['comp_time_factor']
self.trans_time_factor = params['trans_time_factor']
self.car_time_factor = params['car_time_factor']
self.bs_time_factor = params['bs_time_factor']
self.flight_energy_factor = params['flight_energy_factor']
self.comp_energy_factor = params['comp_energy_factor']
self.trans_energy_factor = params['trans_energy_factor']
self.battery_energy_capacity = params['battery_energy_capacity']
self.state = initial_state
self.cut_index = cut_index
self.config = config
self.car_paths = car_paths
self.best_value = float('inf')
self.no_improvement_count = 0
self.last_state = None
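
    # Sketch of the params YAML this environment is assumed to read (keys taken from the
    # attribute assignments above; the numbers are placeholders, not the project's values):
    #   H: 50
    #   W: 50
    #   num_cars: 3
    #   flight_time_factor: 3.0
    #   comp_time_factor: 5.0
    #   trans_time_factor: 0.3
    #   car_time_factor: 100.0
    #   bs_time_factor: 5.0
    #   flight_energy_factor: 0.05
    #   comp_energy_factor: 0.05
    #   trans_energy_factor: 0.0025
    #   battery_energy_capacity: 20.0
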
    def step(self, action: Tuple[int, float]) -> Tuple[Tuple[float, float, float], float, bool]:
        """
        Execute one action.

        Args:
            action: a (variable index, delta) tuple

        Returns:
            Tuple[Tuple[float, float, float], float, bool]: (next state, reward, done)
        """
        var_index, delta = action
        new_state = list(self.state)
        new_state[var_index] = round(new_state[var_index] + delta, 2)

        row_cuts_state = new_state[:self.cut_index]
        col_cuts_state = new_state[self.cut_index:]

        # Check the constraints
        if not self._is_valid_state(row_cuts_state) or not self._is_valid_state(col_cuts_state):
            return self.state, -10000.0, True

        current_value = self.calculate_objective(
            row_cuts_state, col_cuts_state)

        # Check the termination conditions
        if self._should_terminate(new_state, current_value):
            return new_state, 12000 - current_value, True

        self._update_state(new_state, current_value)
        return new_state, 12000 - current_value, False
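
    # Note on the reward shaping above (editor's reading, not an authorial comment):
    # an invalid cut configuration is punished with a flat -10000 and ends the episode,
    # while a valid one is rewarded with 12000 - T, so smaller completion times T map
    # to larger rewards.
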
    def _is_valid_state(self, state: List[float]) -> bool:
        """
        Check whether a state satisfies the constraints:
        the elements must be strictly increasing and lie in the open interval (0, 1).

        Args:
            state: the state list to check

        Returns:
            bool: whether the constraints are satisfied
        """
        # The list must not be empty
        if not state:
            return False
        # All elements must lie in (0, 1)
        if not all(0 < x < 1 for x in state):
            return False
        # The elements must be strictly increasing
        for i in range(len(state) - 1):
            if state[i] >= state[i + 1]:
                return False
        return True
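
    # Illustrative examples for _is_valid_state (editor-added, not in the original file):
    #   _is_valid_state([0.25, 0.50, 0.75]) -> True   (strictly increasing, all in (0, 1))
    #   _is_valid_state([0.50, 0.30])       -> False  (not increasing)
    #   _is_valid_state([0.00, 0.50])       -> False  (0.0 is not strictly inside (0, 1))
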
    def _should_terminate(self, state: Tuple[float, float, float], value: float) -> bool:
        """Check whether the episode should terminate."""
        # Stop once the objective is below the acceptable threshold
        if value < self.config.TARGET_THRESHOLD:
            return True

        # Count consecutive steps in which the state barely changed
        if self.last_state is not None:
            state_diff = sum(abs(a - b)
                             for a, b in zip(state, self.last_state))
            if state_diff < self.config.MIN_IMPROVEMENT:
                self.no_improvement_count += 1
            else:
                self.no_improvement_count = 0

        # Track the best objective value seen so far
        if value < self.best_value:
            self.best_value = value
            self.no_improvement_count = 0
        else:
            self.no_improvement_count += 1

        return self.no_improvement_count >= self.config.MAX_NO_IMPROVEMENT
    def _update_state(self, state: Tuple[float, float, float], value: float) -> None:
        """Update the current and previous state."""
        self.last_state = self.state
        self.state = state
    def calculate_objective(self, row_cuts, col_cuts):
        """
        Compute the objective value T for the given cut ratios (placeholder function).

        :param row_cuts: row cut ratios
        :param col_cuts: column cut ratios
        :return: objective value T
        """
row_cuts = [0] + row_cuts + [1]
col_cuts = [0] + col_cuts + [1]
rectangles = []
for i in range(len(row_cuts) - 1):
for j in range(len(col_cuts) - 1):
d = (col_cuts[j+1] - col_cuts[j]) * self.W * \
(row_cuts[i+1] - row_cuts[i]) * self.H
rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
(self.comp_time_factor - self.trans_time_factor)
rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
(self.comp_energy_factor * d -
self.trans_energy_factor * d)
if rho_energy_limit < 0:
return 1000000
rho = min(rho_time_limit, rho_energy_limit)
flight_time = self.flight_time_factor * d
bs_time = self.bs_time_factor * (1 - rho) * d
rectangles.append({
'flight_time': flight_time,
'bs_time': bs_time,
'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * self.H,
(col_cuts[j] + col_cuts[j+1]) / 2.0 * self.W)
})
        motorcade_time_lt = []
for idx in range(self.num_cars):
car_path = self.car_paths[idx]
flight_time = sum(rectangles[point]['flight_time']
for point in car_path)
bs_time = sum(rectangles[point]['bs_time'] for point in car_path)
car_time = 0
for i in range(len(car_path) - 1):
first_point = car_path[i]
second_point = car_path[i + 1]
car_time += math.dist(
rectangles[first_point]['center'], rectangles[second_point]['center']) * self.car_time_factor
car_time += math.dist(rectangles[car_path[0]]['center'],
[self.H / 2, self.W / 2]) * self.car_time_factor
car_time += math.dist(rectangles[car_path[-1]]['center'],
[self.H / 2, self.W / 2]) * self.car_time_factor
            motorcade_time_lt.append(max(car_time + flight_time, bs_time))
        return max(motorcade_time_lt)
    def reset(self, state: Tuple[float, float, float]) -> Tuple[float, float, float]:
        """Reset the environment state."""
        self.state = state
        return self.state

class QLearning:
    """Q-learning algorithm implementation."""

    def __init__(self, config: Config):
        self.config = config
        self.Q_table: Dict[Tuple[float, float, float],
                           Dict[Tuple[int, float], float]] = {}

    def get_Q(self, state: Tuple[float, float, float], action: Tuple[int, float]) -> float:
        """Get a Q-value."""
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        return self.Q_table[state][action]

    def set_Q(self, state: Tuple[float, float, float], action: Tuple[int, float], value: float) -> None:
        """Set a Q-value."""
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        self.Q_table[state][action] = value

    def choose_action(self, state: Tuple[float, float, float], epsilon: float) -> Tuple[int, float]:
        """Choose an action (ε-greedy policy)."""
        if random.random() < epsilon:
            return random.choice(self.config.ACTIONS)
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        return max(self.Q_table[state].items(), key=lambda x: x[1])[0]
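
# Minimal usage sketch for QLearning (illustrative, editor-added; not executed here):
#   config = Config()
#   agent = QLearning(config)
#   s = (0.3, 0.5, 0.7)
#   a = agent.choose_action(s, epsilon=0.2)   # e.g. (1, 0.01): raise variable 1 by one step
#   agent.set_Q(s, a, 1.0)
#   assert agent.get_Q(s, a) == 1.0
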
def load_initial_solution(file_path: str) -> Tuple[List[float], List[float], List[List[int]]]:
    """
    Load the initial solution from a JSON file.

    Args:
        file_path: path to the JSON file

    Returns:
        Tuple[List[float], List[float], List[List[int]]]: (row cut ratios, column cut ratios, car paths)
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data['row_boundaries'], data['col_boundaries'], data['car_paths']
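
# Sketch of the expected JSON layout, inferred from the keys read above (the concrete
# numbers are placeholders, not values from the original solution file):
#   {
#     "row_boundaries": [0.0, 0.3, 0.7, 1.0],
#     "col_boundaries": [0.0, 0.5, 1.0],
#     "car_paths": [[0, 1], [2, 3], [4, 5]]
#   }
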
def main():
    """Main entry point."""
    # Initialize the configuration
    config = Config()
    random.seed(42)

    # Load the initial solution
    solution_path = r"solutions\trav_ga_params2_parallel.json"
    params_file = r"params2"
    initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution(
        solution_path)

    # Drop the fixed outer boundaries (0 and 1); the state holds only the interior cuts
    initial_state = initial_row_cuts[1:-1] + initial_col_cuts[1:-1]
    cut_index = len(initial_row_cuts) - 2

    # Initialize the environment and the Q-learning agent
    env = FunctionEnv(params_file, initial_state, cut_index, car_paths, config)
    q_learning = QLearning(config)
    # Training loop
    for episode in range(config.NUM_EPISODES):
        print(f"Episode {episode + 1} of {config.NUM_EPISODES}")
        state = env.reset(initial_state)
        done = False

        while not done:
            action = q_learning.choose_action(tuple(state), config.EPSILON)
            next_state, reward, done = env.step(action)
            next_state = tuple(next_state)

            # Q-learning update:
            #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
            if next_state not in q_learning.Q_table:
                q_learning.Q_table[next_state] = {
                    a: 0.0 for a in config.ACTIONS}
            max_next_Q = max(q_learning.Q_table[next_state].values())
            current_Q = q_learning.get_Q(tuple(state), action)
            new_Q = current_Q + config.ALPHA * \
                (reward + config.GAMMA * max_next_Q - current_Q)
            q_learning.set_Q(tuple(state), action, new_Q)

            state = next_state

        # Decay the exploration rate
        config.EPSILON = max(0.01, config.EPSILON * 0.999)
    # Report the best solution found: evaluate every state visited in the Q-table
    best_state = None
    best_value = float('inf')
    for state in q_learning.Q_table:
        state = list(state)
        state_value = env.calculate_objective(
            state[:cut_index], state[cut_index:])
        if state_value < best_value:
            best_value = state_value
            best_state = state
    print("Best state found:", best_state, "objective value:", best_value)

if __name__ == "__main__":
main()
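
# To run this script (assumptions: executed from the Q_learning directory of the repo,
# with solutions\trav_ga_params2_parallel.json and params2.yml available next to it):
#   python q_table.py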