From 1485fb2bd6fa694155dca3e12642abf994d99441 Mon Sep 17 00:00:00 2001
From: weixin_46229132
Date: Thu, 27 Mar 2025 21:48:07 +0800
Subject: [PATCH] Update q_table
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Q_learning/TSP.py     | 216 +++++++++++++++++++++
 Q_learning/q_table.py | 428 +++++++++++++++++++++++++-----------------
 2 files changed, 472 insertions(+), 172 deletions(-)
 create mode 100644 Q_learning/TSP.py

diff --git a/Q_learning/TSP.py b/Q_learning/TSP.py
new file mode 100644
index 0000000..2d4507b
--- /dev/null
+++ b/Q_learning/TSP.py
@@ -0,0 +1,216 @@
+# %matplotlib inline  # notebook-only magic; kept as a comment so the file imports as a module
+import pylab as plt
+from IPython.display import clear_output
+import numpy as np
+import asyncio
+
+class TSP(object):
+    '''
+    Solve the TSP (travelling salesman problem) with Q-Learning.
+    Author: Surfer Zen @ https://www.zhihu.com/people/surfer-zen
+    '''
+    def __init__(self,
+                 num_cities=15,
+                 map_size=(800.0, 600.0),
+                 alpha=2,
+                 beta=1,
+                 learning_rate=0.001,
+                 eps=0.1,
+                 ):
+        '''
+        Args:
+            num_cities (int): number of cities
+            map_size (int, int): map size (width, height)
+            alpha (float): hyperparameter; the larger it is, the more the policy prefers nearby cities
+            beta (float): hyperparameter; the larger it is, the more the policy prefers cities likely to lead to a shorter total tour
+            learning_rate (float): learning rate
+            eps (float): exploration rate; larger values explore more aggressively but make convergence harder
+        '''
+        self.num_cities = num_cities
+        self.map_size = map_size
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+        self.learning_rate = learning_rate
+        self.cities = self.generate_cities()
+        self.distances = self.get_dist_matrix()
+        self.mean_distance = self.distances.mean()
+        self.qualities = np.zeros([num_cities, num_cities])
+        self.normalizers = np.zeros(num_cities)
+        self.best_path = None
+        self.best_path_length = np.inf
+
+    def generate_cities(self):
+        '''
+        Randomly generate the city coordinates.
+        Returns:
+            cities: [[x1, x2, x3...], [y1, y2, y3...]] city coordinates
+        '''
+        max_width, max_height = self.map_size
+        cities = np.random.random([2, self.num_cities]) \
+            * np.array([max_width, max_height]).reshape(2, -1)
+        return cities
+
+    def get_dist_matrix(self):
+        '''
+        Compute the distance matrix from the city coordinates.
+        '''
+        dist_matrix = np.zeros([self.num_cities, self.num_cities])
+        for i in range(self.num_cities):
+            for j in range(self.num_cities):
+                if i == j:
+                    continue
+                xi, xj = self.cities[0, i], self.cities[0, j]
+                yi, yj = self.cities[1, i], self.cities[1, j]
+                dist_matrix[i, j] = np.sqrt((xi-xj)**2 + (yi-yj)**2)
+        return dist_matrix
+
+    def rollout(self, start_city_id=None):
+        '''
+        Starting from start_city, walk between cities following the policy
+        until every city has been visited once.
+        '''
+        cities_visited = []
+        action_probs = []
+        if start_city_id is None:
+            start_city_id = np.random.randint(self.num_cities)
+        current_city_id = start_city_id
+        cities_visited.append(current_city_id)
+        while len(cities_visited) < self.num_cities:
+            current_city_id, action_prob = self.choose_next_city(cities_visited)
+            cities_visited.append(current_city_id)
+            action_probs.append(action_prob)
+        cities_visited.append(cities_visited[0])
+        action_probs.append(1.0)
+
+        path_length = self.calc_path_length(cities_visited)
+        if path_length < self.best_path_length:
+            self.best_path = cities_visited
+            self.best_path_length = path_length
+        rewards = self.calc_path_rewards(cities_visited, path_length)
+        return cities_visited, action_probs, rewards
+
+    def choose_next_city(self, cities_visited):
+        '''
+        Choose the next city according to the policy.
+        '''
+        current_city_id = cities_visited[-1]
+
+        # Exponentiate the qualities; used below to build the softmax probabilities
+        probabilities = np.exp(self.qualities[current_city_id])
+
+        # Zero out the probabilities of cities that have already been visited
+        for city_visited in cities_visited:
+            probabilities[city_visited] = 0
+
+        # Normalize into softmax probabilities
+        probabilities = probabilities/probabilities.sum()
+
+        if np.random.random() < self.eps:
+            # with probability eps, sample according to the softmax distribution
+            next_city_id = np.random.choice(range(len(probabilities)), p=probabilities)
+        else:
+            # with probability (1 - eps), pick the current greedy choice
+            next_city_id = probabilities.argmax()
+
+        # probability of the chosen action under this behavior policy
+        if probabilities.argmax() == next_city_id:
+            action_prob = probabilities[next_city_id]*self.eps + (1-self.eps)
+        else:
+            action_prob = probabilities[next_city_id]*self.eps
+
+        return next_city_id, action_prob
+
+    def calc_path_rewards(self, path, path_length):
+        '''
+        Compute the rewards for a given path.
+        Args:
+            path (list[int]): the path; each element is a city id
+            path_length (float): length of the path
+        Returns:
+            rewards: per-step rewards; the larger the total distance and the
+                current step's distance, the smaller the reward
+        '''
+        rewards = []
+        for fr, to in zip(path[:-1], path[1:]):
+            dist = self.distances[fr, to]
+            reward = (self.mean_distance/path_length)**self.beta
+            reward = reward*(self.mean_distance/dist)**self.alpha
+            rewards.append(np.log(reward))
+        return rewards
+
+    def calc_path_length(self, path):
+        '''
+        Compute the length of a path.
+        '''
+        path_length = 0
+        for fr, to in zip(path[:-1], path[1:]):
+            path_length += self.distances[fr, to]
+        return path_length
+
+    def calc_updates_for_one_rollout(self, path, action_probs, rewards):
+        '''
+        For the result of one rollout, compute the corresponding
+        qualities and normalizers.
+        '''
+        qualities = []
+        normalizers = []
+        for fr, to, reward, action_prob in zip(path[:-1], path[1:], rewards, action_probs):
+            log_action_probability = np.log(action_prob)
+            qualities.append(- reward*log_action_probability)
+            normalizers.append(- (reward + 1)*log_action_probability)
+        return qualities, normalizers
+
+    def update(self, path, new_qualities, new_normalizers):
+        '''
+        Update the qualities and normalizers using a running-average scheme.
+        '''
+        lr = self.learning_rate
+        for fr, to, new_quality, new_normalizer in zip(
+                path[:-1], path[1:], new_qualities, new_normalizers):
+            self.normalizers[fr] = (1-lr)*self.normalizers[fr] + lr*new_normalizer
+            self.qualities[fr, to] = (1-lr)*self.qualities[fr, to] + lr*new_quality
+
+    async def train_for_one_rollout(self, start_city_id):
+        '''
+        Training procedure for a single rollout.
+        '''
+        path, action_probs, rewards = self.rollout(start_city_id=start_city_id)
+        new_qualities, new_normalizers = self.calc_updates_for_one_rollout(path, action_probs, rewards)
+        self.update(path, new_qualities, new_normalizers)
+
+    async def train_for_one_epoch(self):
+        '''
+        Training procedure for one epoch; an epoch corresponds to
+        one rollout started from every city.
+        '''
+        tasks = [self.train_for_one_rollout(start_city_id) for start_city_id in range(self.num_cities)]
+        await asyncio.gather(*tasks)
+
+    async def train(self, num_epochs=1000, display=True):
+        '''
+        Overall training loop.
+        '''
+        for epoch in range(num_epochs):
+            await self.train_for_one_epoch()
+            if display:
+                self.draw(epoch)
+
+    def draw(self, epoch):
+        '''
+        Plot the current best path and training progress.
+        '''
+        _ = plt.scatter(*self.cities)
+        for fr, to in zip(self.best_path[:-1], self.best_path[1:]):
+            x1, y1 = self.cities[:, fr]
+            x2, y2 = self.cities[:, to]
+            dx, dy = x2-x1, y2-y1
+            plt.arrow(x1, y1, dx, dy, width=0.01*min(self.map_size),
+                      edgecolor='orange', facecolor='white', animated=True,
+                      length_includes_head=True)
+        nrs = np.exp(self.qualities)
+        for i in range(self.num_cities):
+            nrs[i, i] = 0
+        gap = np.abs(np.exp(self.normalizers) - nrs.sum(-1)).mean()
+        plt.title(f'epoch {epoch}: path length = {self.best_path_length:.2f}, normalizer error = {gap:.3f}')
+        plt.savefig('tsp.png')
+        plt.show()
+        clear_output(wait=True)
\ No newline at end of file
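
(Illustrative usage sketch, not part of the patch: TSP.train() is a coroutine, so a plain-script driver has to await it, e.g. via asyncio.run. The file name, module path, epoch count and flag values below are assumptions.)

# run_tsp.py -- hypothetical driver script, assuming Q_learning/ is the working directory
import asyncio
from TSP import TSP

tsp = TSP(num_cities=15, eps=0.1, learning_rate=0.001)
# display=False skips the notebook-only drawing/clear_output path
asyncio.run(tsp.train(num_epochs=200, display=False))
print("best path length:", tsp.best_path_length)
print("best path:", tsp.best_path)
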
diff --git a/Q_learning/q_table.py b/Q_learning/q_table.py
index 44fffd1..383f476 100644
--- a/Q_learning/q_table.py
+++ b/Q_learning/q_table.py
@@ -3,232 +3,316 @@
 import numpy as np
 import json
 import math
 import yaml
-# Parameter settings
-STEP = 0.01
-VALUES = [round(i*STEP, 2) for i in range(101)]  # 0.00~1.00
-ACTION_DELTA = [STEP, -STEP]  # increase or decrease by 0.01
-ACTIONS = []  # each action is (var_index, delta)
-for i in range(3):
-    for delta in ACTION_DELTA:
-        ACTIONS.append((i, delta))
-
-ALPHA = 0.1  # learning rate
-GAMMA = 0.9  # discount factor
-EPSILON = 0.2  # exploration rate
-NUM_EPISODES = 100
+from typing import Tuple, List, Dict, Any, Optional
+from dataclasses import dataclass
 
-def f(state):
-    """
-    Compute the objective value T for the given cut ratios (placeholder).
-    :param row_cuts: row cut ratios
-    :param col_cuts: column cut ratios
-    :return: objective value T
-    """
-    with open('params2.yml', 'r', encoding='utf-8') as file:
-        params = yaml.safe_load(file)
-    H = params['H']
-    W = params['W']
-    num_cars = params['num_cars']
+@dataclass
+class Config:
+    """Configuration class that gathers all hyperparameters in one place."""
+    # Learning parameters
+    ALPHA: float = 0.1       # learning rate
+    GAMMA: float = 0.9       # discount factor
+    EPSILON: float = 0.2     # exploration rate
+    NUM_EPISODES: int = 100  # number of training episodes
 
-    flight_time_factor = params['flight_time_factor']
-    comp_time_factor = params['comp_time_factor']
-    trans_time_factor = params['trans_time_factor']
-    car_time_factor = params['car_time_factor']
-    bs_time_factor = params['bs_time_factor']
+    # State-space parameters
+    STEP: float = 0.01           # step size for state changes
+    VALUES: List[float] = None   # list of discrete values 0.00~1.00
 
-    flight_energy_factor = params['flight_energy_factor']
-    comp_energy_factor = params['comp_energy_factor']
-    trans_energy_factor = params['trans_energy_factor']
-    battery_energy_capacity = params['battery_energy_capacity']
+    # Action-space parameters
+    ACTION_DELTA: List[float] = None  # per-action change amounts
 
-    col_cuts = list(state)
-    col_cuts.insert(0, 0)
-    col_cuts.append(1)
-    row_cuts = [0, 0.5, 1]
-    rectangles = []
-    for i in range(len(row_cuts) - 1):
-        for j in range(len(col_cuts) - 1):
-            d = (col_cuts[j+1] - col_cuts[j]) * W * \
-                (row_cuts[i+1] - row_cuts[i]) * H
-            rho_time_limit = (flight_time_factor - trans_time_factor) / \
-                (comp_time_factor - trans_time_factor)
-            rho_energy_limit = (battery_energy_capacity - flight_energy_factor * d - trans_energy_factor * d) / (comp_energy_factor * d - trans_energy_factor * d)
-            if rho_energy_limit < 0:
-                return 100000
-            rho = min(rho_time_limit, rho_energy_limit)
+    # Environment parameters
+    MIN_IMPROVEMENT: float = 0.001   # minimum improvement threshold
+    MAX_NO_IMPROVEMENT: int = 10     # maximum allowed consecutive non-improving steps
+    TARGET_THRESHOLD: float = 10000  # acceptable objective value threshold
 
-            flight_time = flight_time_factor * d
-            bs_time = bs_time_factor * (1 - rho) * d
+    def __post_init__(self):
+        """Initialize derived parameters."""
+        self.VALUES = [round(i*self.STEP, 2) for i in range(101)]
+        self.ACTION_DELTA = [self.STEP, -self.STEP]
+        # TODO: the hard-coded 4 must match the number of state variables
+        self.ACTIONS = [(i, delta) for i in range(4)
+                        for delta in self.ACTION_DELTA]
 
-            rectangles.append({
-                'flight_time': flight_time,
-                'bs_time': bs_time,
-                'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * H,
-                           (col_cuts[j] + col_cuts[j+1]) / 2.0 * W)
-            })
-    mortorcade_time_lt = []
-    for idx in range(num_cars):
-        car_path = car_paths[idx]
-
-        flight_time = sum(rectangles[point]['flight_time']
-                          for point in car_path)
-        bs_time = sum(rectangles[point]['bs_time'] for point in car_path)
-
-        car_time = 0
-        for i in range(len(car_path) - 1):
-            first_point = car_path[i]
-            second_point = car_path[i + 1]
-            car_time += math.dist(
-                rectangles[first_point]['center'], rectangles[second_point]['center']) * car_time_factor
-        car_time += math.dist(rectangles[car_path[0]]['center'],
-                              [H / 2, W / 2]) * car_time_factor
-        car_time += math.dist(rectangles[car_path[-1]]['center'],
-                              [H / 2, W / 2]) * car_time_factor
-        mortorcade_time_lt.append(max(car_time + flight_time, bs_time))
-
-    return max(mortorcade_time_lt)
-
-# Environment class: defines state transitions and rewards
 class FunctionEnv:
-    def __init__(self, initial_state):
-        self.state = initial_state  # initial state (x1, x2, x3)
-        self.best_value = float('inf')  # best value seen so far
-        self.no_improvement_count = 0  # consecutive steps without improvement
-        self.last_state = None  # previous state
-        self.min_improvement = 0.001  # minimum improvement threshold
-        self.max_no_improvement = 10  # maximum allowed consecutive non-improving steps
-        self.target_threshold = 10000  # acceptable objective value threshold
-
-    def step(self, action):
-        # action: (var_index, delta)
+    """Environment class: defines state transitions and rewards."""
+
+    def __init__(self, params_file: str, initial_state: Tuple[float, float, float], cut_index: int, car_paths: List[List[int]], config: Config):
+        self.params_file = params_file
+        with open(self.params_file + '.yml', 'r', encoding='utf-8') as file:
+            params = yaml.safe_load(file)
+
+        self.H = params['H']
+        self.W = params['W']
+        self.num_cars = params['num_cars']
+
+        self.flight_time_factor = params['flight_time_factor']
+        self.comp_time_factor = params['comp_time_factor']
+        self.trans_time_factor = params['trans_time_factor']
+        self.car_time_factor = params['car_time_factor']
+        self.bs_time_factor = params['bs_time_factor']
+
+        self.flight_energy_factor = params['flight_energy_factor']
+        self.comp_energy_factor = params['comp_energy_factor']
+        self.trans_energy_factor = params['trans_energy_factor']
+        self.battery_energy_capacity = params['battery_energy_capacity']
+
+        self.state = initial_state
+        self.cut_index = cut_index
+        self.config = config
+        self.car_paths = car_paths
+        self.best_value = float('inf')
+        self.no_improvement_count = 0
+        self.last_state = None
+
+    def step(self, action: Tuple[int, float]) -> Tuple[Tuple[float, float, float], float, bool]:
+        """
+        Execute one action step.
+
+        Args:
+            action: a (variable index, delta) tuple
+
+        Returns:
+            Tuple[Tuple[float, float, float], float, bool]: (next state, reward, done)
+        """
         var_index, delta = action
         new_state = list(self.state)
         new_state[var_index] = round(new_state[var_index] + delta, 2)
-        # keep the value within the range [0, 1]
-        if new_state[var_index] < 0 or new_state[var_index] > 1:
-            return self.state, -10000.0, True  # end of episode
-        # check the constraint x1 < x2 < x3
-        if not (0 < new_state[0] < new_state[1] < new_state[2] < 1):
+        row_cuts_state = new_state[:self.cut_index]
+        col_cuts_state = new_state[self.cut_index:]
+
+        # Check the constraints
+        if not self._is_valid_state(row_cuts_state) or not self._is_valid_state(col_cuts_state):
             return self.state, -10000.0, True
+
+        current_value = self.calculate_objective(
+            row_cuts_state, col_cuts_state)
+
+        # Check the termination conditions
+        if self._should_terminate(new_state, current_value):
+            return new_state, 12000 - current_value, True
+
+        self._update_state(new_state, current_value)
+        return new_state, 12000 - current_value, False
+
+    def _is_valid_state(self, state: List[float]) -> bool:
+        """
+        Check whether a state satisfies the constraints:
+        all elements must lie in (0, 1) and be strictly increasing.
 
-        next_state = tuple(new_state)
-        current_value = f(next_state)
-
-        # check whether the target threshold has been reached
-        if current_value < self.target_threshold:
-            return next_state, 12000 - current_value, True
+        Args:
+            state: the list of values to check
 
-        # check whether the state change is very small
+        Returns:
+            bool: whether the constraints are satisfied
+        """
+        # an empty list is invalid
+        if not state:
+            return False
+
+        # all elements must be in (0, 1)
+        if not all(0 < x < 1 for x in state):
+            return False
+
+        # elements must be strictly increasing
+        for i in range(len(state) - 1):
+            if state[i] >= state[i + 1]:
+                return False
+
+        return True
+
+    def _should_terminate(self, state: Tuple[float, float, float], value: float) -> bool:
+        """Check whether the episode should terminate."""
+        if value < self.config.TARGET_THRESHOLD:
+            return True
+
         if self.last_state is not None:
-            state_diff = sum(abs(a - b) for a, b in zip(next_state, self.last_state))
-            if state_diff < self.min_improvement:
+            state_diff = sum(abs(a - b)
+                             for a, b in zip(state, self.last_state))
+            if state_diff < self.config.MIN_IMPROVEMENT:
                 self.no_improvement_count += 1
             else:
                 self.no_improvement_count = 0
-
-        # check whether there was an improvement
-        if current_value < self.best_value:
-            self.best_value = current_value
+
+        if value < self.best_value:
+            self.best_value = value
             self.no_improvement_count = 0
         else:
             self.no_improvement_count += 1
-
-        # end the episode after too many consecutive non-improving steps
-        if self.no_improvement_count >= self.max_no_improvement:
-            return next_state, 12000 - current_value, True
-
-        self.last_state = next_state
-        self.state = next_state
-        return next_state, 12000 - current_value, False
-
-    def reset(self, state):
+
+        return self.no_improvement_count >= self.config.MAX_NO_IMPROVEMENT
+
+    def _update_state(self, state: Tuple[float, float, float], value: float) -> None:
+        """Update the stored state."""
+        self.last_state = self.state
+        self.state = state
+
+    def calculate_objective(self, row_cuts, col_cuts):
+        """
+        Compute the objective value T for the given cut ratios.
+        :param row_cuts: row cut ratios
+        :param col_cuts: column cut ratios
+        :return: objective value T
+        """
+        row_cuts = [0] + row_cuts + [1]
+        col_cuts = [0] + col_cuts + [1]
+        rectangles = []
+        for i in range(len(row_cuts) - 1):
+            for j in range(len(col_cuts) - 1):
+                d = (col_cuts[j+1] - col_cuts[j]) * self.W * \
+                    (row_cuts[i+1] - row_cuts[i]) * self.H
+                rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
+                    (self.comp_time_factor - self.trans_time_factor)
+                rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
+                    (self.comp_energy_factor * d -
+                     self.trans_energy_factor * d)
+                if rho_energy_limit < 0:
+                    return 1000000
+                rho = min(rho_time_limit, rho_energy_limit)
+
+                flight_time = self.flight_time_factor * d
+                bs_time = self.bs_time_factor * (1 - rho) * d
+
+                rectangles.append({
+                    'flight_time': flight_time,
+                    'bs_time': bs_time,
+                    'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * self.H,
+                               (col_cuts[j] + col_cuts[j+1]) / 2.0 * self.W)
+                })
+
+        mortorcade_time_lt = []
+        for idx in range(self.num_cars):
+            car_path = self.car_paths[idx]
+
+            flight_time = sum(rectangles[point]['flight_time']
+                              for point in car_path)
+            bs_time = sum(rectangles[point]['bs_time'] for point in car_path)
+
+            car_time = 0
+            for i in range(len(car_path) - 1):
+                first_point = car_path[i]
+                second_point = car_path[i + 1]
+                car_time += math.dist(
+                    rectangles[first_point]['center'], rectangles[second_point]['center']) * self.car_time_factor
+            car_time += math.dist(rectangles[car_path[0]]['center'],
+                                  [self.H / 2, self.W / 2]) * self.car_time_factor
+            car_time += math.dist(rectangles[car_path[-1]]['center'],
+                                  [self.H / 2, self.W / 2]) * self.car_time_factor
+            mortorcade_time_lt.append(max(car_time + flight_time, bs_time))
+
+        return max(mortorcade_time_lt)
+
+    def reset(self, state: Tuple[float, float, float]) -> Tuple[float, float, float]:
+        """Reset the environment state."""
         self.state = state
         return self.state
 
-# Initialize the Q-table as a dict: key = state tuple, value = dict mapping action -> Q-value
-Q_table = {}
 
-def get_Q(state, action):
-    if state not in Q_table:
-        Q_table[state] = {a: 0.0 for a in ACTIONS}
-    return Q_table[state][action]
+class QLearning:
+    """Q-learning algorithm implementation."""
 
-def set_Q(state, action, value):
-    if state not in Q_table:
-        Q_table[state] = {a: 0.0 for a in ACTIONS}
-    Q_table[state][action] = value
+    def __init__(self, config: Config):
+        self.config = config
+        self.Q_table: Dict[Tuple[float, float, float],
+                           Dict[Tuple[int, float], float]] = {}
 
-def choose_action(state, epsilon):
-    # ε-greedy policy
-    if random.random() < epsilon:
-        return random.choice(ACTIONS)
-    else:
-        if state not in Q_table:
-            Q_table[state] = {a: 0.0 for a in ACTIONS}
-        # return the action with the largest Q-value
-        return max(Q_table[state].items(), key=lambda x: x[1])[0]
+    def get_Q(self, state: Tuple[float, float, float], action: Tuple[int, float]) -> float:
+        """Get a Q-value."""
+        if state not in self.Q_table:
+            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
+        return self.Q_table[state][action]
 
-def load_initial_solution(file_path):
+    def set_Q(self, state: Tuple[float, float, float], action: Tuple[int, float], value: float) -> None:
+        """Set a Q-value."""
+        if state not in self.Q_table:
+            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
+        self.Q_table[state][action] = value
+
+    def choose_action(self, state: Tuple[float, float, float], epsilon: float) -> Tuple[int, float]:
+        """Choose an action (ε-greedy policy)."""
+        if random.random() < epsilon:
+            return random.choice(self.config.ACTIONS)
+        if state not in self.Q_table:
+            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
+        return max(self.Q_table[state].items(), key=lambda x: x[1])[0]
+
+
+def load_initial_solution(file_path: str) -> Tuple[List[float], List[float], List[List[int]]]:
     """
-    Load the initial solution from a JSON file
-    :param file_path: path to the JSON file
-    :return: row cut ratios and column cut ratios
+    Load the initial solution from a JSON file.
+
+    Args:
+        file_path: path to the JSON file
+
+    Returns:
+        Tuple[List[float], List[float], List[List[int]]]: (row cut ratios, column cut ratios, car paths)
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
-    row_cuts = data['row_boundaries']
-    col_cuts = data['col_boundaries']
-    car_paths = data['car_paths']
-    return row_cuts, col_cuts, car_paths
+    return data['row_boundaries'], data['col_boundaries'], data['car_paths']
 
-if __name__ == "__main__":
+
+def main():
+    """Main entry point."""
+    # Initialize the configuration
+    config = Config()
     random.seed(42)
-    # ---------------------------
-    # hyperparameters that need to be adjusted
-    # ---------------------------
+
+    # Load the initial solution
     solution_path = r"solutions\trav_ga_params2_parallel.json"
     params_file = r"params2"
     initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution(
         solution_path)
-    initial_state = (0.2, 0.4, 0.7)
-    # Q-learning main loop
-    env = FunctionEnv(initial_state)
+    initial_state = initial_row_cuts[1:-1] + initial_col_cuts[1:-1]
+    cut_index = len(initial_row_cuts) - 2
 
-    for episode in range(NUM_EPISODES):
-        print(f"Episode {episode + 1} of {NUM_EPISODES}")
+    # Initialize the environment and the Q-learning agent
+    env = FunctionEnv(params_file, initial_state, cut_index, car_paths, config)
+    q_learning = QLearning(config)
+
+    # Training loop
+    for episode in range(config.NUM_EPISODES):
+        print(f"Episode {episode + 1} of {config.NUM_EPISODES}")
         state = env.reset(initial_state)
         done = False
-
+
         while not done:
-            # choose an action
-            action = choose_action(state, EPSILON)
-            # let the environment execute the action
+            action = q_learning.choose_action(tuple(state), config.EPSILON)
             next_state, reward, done = env.step(action)
-            # Q-learning update: Q(s,a) = Q(s,a) + α [r + γ * max_a' Q(s', a') - Q(s,a)]
-            if next_state not in Q_table:
-                Q_table[next_state] = {a: 0.0 for a in ACTIONS}
-            max_next_Q = max(Q_table[next_state].values())
-            current_Q = get_Q(state, action)
-            new_Q = current_Q + ALPHA * (reward + GAMMA * max_next_Q - current_Q)
-            set_Q(state, action, new_Q)
+            next_state = tuple(next_state)
+
+            # Q-learning update
+            if next_state not in q_learning.Q_table:
+                q_learning.Q_table[next_state] = {
+                    a: 0.0 for a in config.ACTIONS}
+            max_next_Q = max(q_learning.Q_table[next_state].values())
+            current_Q = q_learning.get_Q(tuple(state), action)
+            new_Q = current_Q + config.ALPHA * \
+                (reward + config.GAMMA * max_next_Q - current_Q)
+            q_learning.set_Q(tuple(state), action, new_Q)
             state = next_state
 
-        # gradually decay the exploration rate
-        EPSILON = max(0.01, EPSILON * 0.999)
+        # Update the exploration rate
+        config.EPSILON = max(0.01, config.EPSILON * 0.999)
 
-    # Print the best state and value of the policy found in the Q-table
+    # Report the best solution found
     best_state = None
     best_value = float('inf')
-    for state in Q_table:
-        # evaluate a solution by its objective value
-        state_value = f(state)
+    for state in q_learning.Q_table:
+        state = list(state)
+        state_value = env.calculate_objective(
+            state[:cut_index], state[cut_index:])
        if state_value < best_value:
            best_value = state_value
            best_state = state
     print("Best state found:", best_state, "objective value:", best_value)
+
+
+if __name__ == "__main__":
+    main()
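
(For reference, the update performed inside main() is the standard tabular rule Q(s,a) <- Q(s,a) + ALPHA * (r + GAMMA * max_a' Q(s',a') - Q(s,a)). Below is a minimal self-contained sketch of that rule using the same ALPHA/GAMMA defaults as Config but with made-up toy states and actions; it is illustration only and not part of this patch.)

# toy_q_update.py -- illustrative sketch; state/action names are hypothetical
ALPHA, GAMMA = 0.1, 0.9
Q = {("s0",): {"a0": 0.0, "a1": 0.0},
     ("s1",): {"a0": 0.0, "a1": 0.0}}

state, action, reward, next_state = ("s0",), "a0", 5.0, ("s1",)
max_next_Q = max(Q[next_state].values())  # max_a' Q(s', a')
Q[state][action] += ALPHA * (reward + GAMMA * max_next_Q - Q[state][action])
print(Q[state][action])  # 0.1 * (5.0 + 0.9 * 0.0 - 0.0) = 0.5
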