HPCC2025/Q_learning/q_table.py
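"""Tabular Q-learning for tuning discretized row/column cut ratios.

The script loads an initial partition (row/column boundaries and per-car rectangle
paths) from a JSON solution file, wraps the objective computation in ``FunctionEnv``,
and runs ε-greedy Q-learning that nudges individual cut positions by ±STEP to reduce
the maximum per-car completion time returned by ``calculate_objective``.
"""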

import random
import numpy as np
import json
import math
import yaml
from typing import Tuple, List, Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class Config:
"""配置类,集中管理所有超参数"""
# 学习参数
ALPHA: float = 0.1 # 学习率
GAMMA: float = 0.9 # 折扣因子
EPSILON: float = 0.2 # 探索率
NUM_EPISODES: int = 100 # 训练回合数
# 状态空间参数
STEP: float = 0.01 # 状态变化步长
VALUES: List[float] = None # 0.00~1.00的离散值列表
# 动作空间参数
ACTION_DELTA: List[float] = None # 动作变化量
# 环境参数
MIN_IMPROVEMENT: float = 0.001 # 最小改善阈值
MAX_NO_IMPROVEMENT: int = 10 # 最大允许连续未改善次数
TARGET_THRESHOLD: float = 10000 # 目标函数值的可接受阈值
def __post_init__(self):
"""初始化依赖参数"""
self.VALUES = [round(i*self.STEP, 2) for i in range(101)]
self.ACTION_DELTA = [self.STEP, -self.STEP]
# TODO 需要修改4个变量
self.ACTIONS = [(i, delta) for i in range(4)
for delta in self.ACTION_DELTA]
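        # Illustrative result with the defaults above: ACTIONS =
        # [(0, 0.01), (0, -0.01), (1, 0.01), (1, -0.01), ..., (3, 0.01), (3, -0.01)],
        # i.e. each action picks one cut variable and moves it up or down by one STEP.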
class FunctionEnv:
"""环境类:定义状态转移与奖励"""
def __init__(self, params_file: str, initial_state: Tuple[float, float, float], cut_index: int, car_paths: List[List[int]], config: Config):
self.params_file = params_file
with open(self.params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file)
self.H = params['H']
self.W = params['W']
self.num_cars = params['num_cars']
self.flight_time_factor = params['flight_time_factor']
self.comp_time_factor = params['comp_time_factor']
self.trans_time_factor = params['trans_time_factor']
self.car_time_factor = params['car_time_factor']
self.bs_time_factor = params['bs_time_factor']
self.flight_energy_factor = params['flight_energy_factor']
self.comp_energy_factor = params['comp_energy_factor']
self.trans_energy_factor = params['trans_energy_factor']
self.battery_energy_capacity = params['battery_energy_capacity']
self.state = initial_state
self.cut_index = cut_index
self.config = config
self.car_paths = car_paths
self.best_value = float('inf')
self.no_improvement_count = 0
self.last_state = None
def step(self, action: Tuple[int, float]) -> Tuple[Tuple[float, float, float], float, bool]:
"""
执行一步动作
Args:
action: (变量索引, 变化量)的元组
Returns:
Tuple[Tuple[float, float, float], float, bool]: (下一个状态, 奖励, 是否结束)
"""
var_index, delta = action
new_state = list(self.state)
new_state[var_index] = round(new_state[var_index] + delta, 2)
row_cuts_state = new_state[:self.cut_index]
col_cuts_state = new_state[self.cut_index:]
        # Check the constraints
if not self._is_valid_state(row_cuts_state) or not self._is_valid_state(col_cuts_state):
return self.state, -10000.0, True
current_value = self.calculate_objective(
row_cuts_state, col_cuts_state)
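        # Reward shaping (as implemented below): the reward is 12000 minus the
        # objective value, so a smaller objective yields a larger reward; invalid
        # states above are punished with -10000 and terminate the episode.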
        # Check the termination conditions
if self._should_terminate(new_state, current_value):
return new_state, 12000 - current_value, True
self._update_state(new_state, current_value)
return new_state, 12000 - current_value, False
def _is_valid_state(self, state: List[float]) -> bool:
"""
检查状态是否满足约束条件
确保列表中的元素严格递增且在(0,1)范围内
Args:
state: 需要检查的状态列表
Returns:
bool: 是否满足约束条件
"""
        # Check whether the list is empty
if not state:
return False
        # Check that every element lies in (0, 1)
if not all(0 < x < 1 for x in state):
return False
        # Check that the elements are strictly increasing
for i in range(len(state) - 1):
if state[i] >= state[i + 1]:
return False
return True
def _should_terminate(self, state: Tuple[float, float, float], value: float) -> bool:
"""检查是否应该终止"""
if value < self.config.TARGET_THRESHOLD:
return True
if self.last_state is not None:
state_diff = sum(abs(a - b)
for a, b in zip(state, self.last_state))
if state_diff < self.config.MIN_IMPROVEMENT:
self.no_improvement_count += 1
else:
self.no_improvement_count = 0
if value < self.best_value:
self.best_value = value
self.no_improvement_count = 0
else:
self.no_improvement_count += 1
return self.no_improvement_count >= self.config.MAX_NO_IMPROVEMENT
def _update_state(self, state: Tuple[float, float, float], value: float) -> None:
"""更新状态"""
self.last_state = self.state
self.state = state
def calculate_objective(self, row_cuts, col_cuts):
"""
计算切分比例的目标值 T占位函数
:param row_cuts: 行切分比例
:param col_cuts: 列切分比例
:return: 目标值 T
"""
row_cuts = [0] + row_cuts + [1]
col_cuts = [0] + col_cuts + [1]
rectangles = []
for i in range(len(row_cuts) - 1):
for j in range(len(col_cuts) - 1):
d = (col_cuts[j+1] - col_cuts[j]) * self.W * \
(row_cuts[i+1] - row_cuts[i]) * self.H
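                # rho appears to be the fraction of the cell's data handled on board,
                # bounded by both a time constraint and a battery-energy constraint
                # (interpretation inferred from the factor names; not documented here).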
rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
(self.comp_time_factor - self.trans_time_factor)
rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
(self.comp_energy_factor * d -
self.trans_energy_factor * d)
if rho_energy_limit < 0:
return 1000000
rho = min(rho_time_limit, rho_energy_limit)
flight_time = self.flight_time_factor * d
bs_time = self.bs_time_factor * (1 - rho) * d
rectangles.append({
'flight_time': flight_time,
'bs_time': bs_time,
'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * self.H,
(col_cuts[j] + col_cuts[j+1]) / 2.0 * self.W)
})
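        # For each car: sum the flight and base-station times of the rectangles on
        # its path, add the driving time between consecutive rectangle centers plus
        # the legs between the region center (H/2, W/2) and the first/last rectangle,
        # then take max(driving + flight, bs processing) as that car's finish time.
        # The objective is the maximum finish time over all cars.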
mortorcade_time_lt = []
for idx in range(self.num_cars):
car_path = self.car_paths[idx]
flight_time = sum(rectangles[point]['flight_time']
for point in car_path)
bs_time = sum(rectangles[point]['bs_time'] for point in car_path)
car_time = 0
for i in range(len(car_path) - 1):
first_point = car_path[i]
second_point = car_path[i + 1]
car_time += math.dist(
rectangles[first_point]['center'], rectangles[second_point]['center']) * self.car_time_factor
car_time += math.dist(rectangles[car_path[0]]['center'],
[self.H / 2, self.W / 2]) * self.car_time_factor
car_time += math.dist(rectangles[car_path[-1]]['center'],
[self.H / 2, self.W / 2]) * self.car_time_factor
mortorcade_time_lt.append(max(car_time + flight_time, bs_time))
return max(mortorcade_time_lt)
def reset(self, state: Tuple[float, float, float]) -> Tuple[float, float, float]:
"""重置环境状态"""
self.state = state
return self.state
class QLearning:
"""Q-learning算法实现"""
def __init__(self, config: Config):
self.config = config
self.Q_table: Dict[Tuple[float, float, float],
Dict[Tuple[int, float], float]] = {}
def get_Q(self, state: Tuple[float, float, float], action: Tuple[int, float]) -> float:
"""获取Q值"""
if state not in self.Q_table:
self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
return self.Q_table[state][action]
def set_Q(self, state: Tuple[float, float, float], action: Tuple[int, float], value: float) -> None:
"""设置Q值"""
if state not in self.Q_table:
self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
self.Q_table[state][action] = value
def choose_action(self, state: Tuple[float, float, float], epsilon: float) -> Tuple[int, float]:
"""选择动作(ε-greedy策略"""
if random.random() < epsilon:
return random.choice(self.config.ACTIONS)
if state not in self.Q_table:
self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
return max(self.Q_table[state].items(), key=lambda x: x[1])[0]
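# Illustrative usage (hypothetical values): QLearning(Config()).choose_action((0.2, 0.5), 0.2)
# returns one element of Config().ACTIONS, e.g. (1, 0.01), i.e. "raise cut variable 1 by one STEP".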
def load_initial_solution(file_path: str) -> Tuple[List[float], List[float], List[List[int]]]:
"""
从JSON文件加载初始解
Args:
file_path: JSON文件路径
Returns:
Tuple[List[float], List[float], List[List[int]]]: (行切分比例, 列切分比例, 车辆路径)
"""
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
return data['row_boundaries'], data['col_boundaries'], data['car_paths']
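# The solution JSON is assumed to have roughly this shape (field names taken from
# the code above; the example values are hypothetical):
# {
#     "row_boundaries": [0.0, 0.3, 0.7, 1.0],
#     "col_boundaries": [0.0, 0.5, 1.0],
#     "car_paths": [[0, 1], [2, 3], [4, 5]]
# }
# The boundary lists include the fixed 0/1 endpoints (stripped in main()), and each
# car path lists rectangle indices in visiting order.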
def main():
"""主函数"""
# 初始化配置
config = Config()
random.seed(42)
    # Load the initial solution
solution_path = r"solutions\trav_ga_params2_parallel.json"
params_file = r"params2"
initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution(
solution_path)
initial_state = initial_row_cuts[1:-1] + initial_col_cuts[1:-1]
cut_index = len(initial_row_cuts) - 2
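    # The state concatenates the interior row cuts and interior column cuts (the
    # fixed 0 and 1 endpoints are dropped); cut_index marks where the row part
    # ends and the column part begins.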
    # Initialize the environment and the Q-learning agent
env = FunctionEnv(params_file, initial_state, cut_index, car_paths, config)
q_learning = QLearning(config)
    # Training loop
for episode in range(config.NUM_EPISODES):
print(f"Episode {episode + 1} of {config.NUM_EPISODES}")
state = env.reset(initial_state)
done = False
while not done:
action = q_learning.choose_action(tuple(state), config.EPSILON)
next_state, reward, done = env.step(action)
next_state = tuple(next_state)
            # Q-learning update
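            # Q(s, a) <- Q(s, a) + ALPHA * (reward + GAMMA * max_a' Q(s', a') - Q(s, a))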
if next_state not in q_learning.Q_table:
q_learning.Q_table[next_state] = {
a: 0.0 for a in config.ACTIONS}
max_next_Q = max(q_learning.Q_table[next_state].values())
current_Q = q_learning.get_Q(tuple(state), action)
new_Q = current_Q + config.ALPHA * \
(reward + config.GAMMA * max_next_Q - current_Q)
q_learning.set_Q(tuple(state), action, new_Q)
state = next_state
        # Decay the exploration rate (floored at 0.01)
config.EPSILON = max(0.01, config.EPSILON * 0.999)
    # Report the best solution found
best_state = None
best_value = float('inf')
for state in q_learning.Q_table:
state = list(state)
state_value = env.calculate_objective(
state[:cut_index], state[cut_index:])
if state_value < best_value:
best_value = state_value
best_state = state
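    # Note: the best state is chosen by re-evaluating the objective for every state
    # visited during training, not by reading the learned Q-values directly.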
print("找到的最优状态:", best_state, "对应函数值:", best_value)
if __name__ == "__main__":
main()