From 6f8fcd15b7dd76e9ebdb629b8e4272cce26731e0 Mon Sep 17 00:00:00 2001 From: weixin_46229132 Date: Thu, 27 Mar 2025 20:50:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E5=85=A5q=20learning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GA/main.py | 2 +- GA/main_parallel.py | 6 +- GA/sa_finetune.py | 23 ++- Q_learning/q_table.py | 234 ++++++++++++++++++++++++ mtkl_sovler.py | 4 +- solutions/finetune_params2.json | 30 +++ solutions/mtkl_params2.json | 22 +-- solutions/trav_ga_params2.json | 30 +++ solutions/trav_ga_params2_parallel.json | 12 +- visualization.py | 2 +- 10 files changed, 339 insertions(+), 26 deletions(-) create mode 100644 Q_learning/q_table.py create mode 100644 solutions/finetune_params2.json create mode 100644 solutions/trav_ga_params2.json diff --git a/GA/main.py b/GA/main.py index dee6286..b273801 100644 --- a/GA/main.py +++ b/GA/main.py @@ -20,7 +20,7 @@ best_col_boundaries = None # --------------------------- R = 3 C = 3 -params_file = 'params3' +params_file = 'params2' with open(params_file + '.yml', 'r', encoding='utf-8') as file: diff --git a/GA/main_parallel.py b/GA/main_parallel.py index 0ad0363..f3915d9 100644 --- a/GA/main_parallel.py +++ b/GA/main_parallel.py @@ -30,9 +30,9 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行 # --------------------------- # 需要修改的超参数 # --------------------------- - R = 1 - C = 1 - params_file = 'params3' + R = 3 + C = 3 + params_file = 'params2' batch_size = 60 # 控制一次最多并行多少个任务 with open(params_file + '.yml', 'r', encoding='utf-8') as file: diff --git a/GA/sa_finetune.py b/GA/sa_finetune.py index b5e1db6..3ba46ec 100644 --- a/GA/sa_finetune.py +++ b/GA/sa_finetune.py @@ -16,7 +16,8 @@ class SA_FineTuner: :param cooling_rate: 温度下降速率 """ # 读取参数 - with open(params_file + '.yml', 'r', encoding='utf-8') as file: + self.params_file = params_file + with open(self.params_file + '.yml', 'r', encoding='utf-8') as file: params = yaml.safe_load(file) self.H = params['H'] @@ -91,6 +92,21 @@ class SA_FineTuner: delta = new_T - current_T acceptance_probability = math.exp(-delta / temperature) return random.random() < acceptance_probability + + def save_best_solution(self, row_cuts, col_cuts, car_paths): + """ + 保存最佳方案 + :param row_cuts: 行切分比例 + :param col_cuts: 列切分比例 + :param car_paths: 车队路径 + """ + output_data = { + 'row_boundaries': row_cuts, + 'col_boundaries': col_cuts, + 'car_paths': car_paths + } + with open(f'./solutions/finetune_{self.params_file}.json', 'w', encoding='utf-8') as file: + json.dump(output_data, file, ensure_ascii=False, indent=4) def T_function(self, row_cuts, col_cuts): """ @@ -181,6 +197,9 @@ class SA_FineTuner: if iteration % 100 == 0: print( f"Iteration {iteration}: Best T = {self.best_T}, Temperature = {self.temperature}") + + # 保存最佳方案 + self.save_best_solution(self.best_row_cuts, self.best_col_cuts, self.car_paths) return self.best_row_cuts, self.best_col_cuts, self.best_T @@ -206,7 +225,7 @@ if __name__ == "__main__": # --------------------------- # 需要修改的超参数 # --------------------------- - solution_path = r"solutions\mtkl_params2.json" + solution_path = r"solutions\trav_ga_params2_parallel.json" params_file = r"params2" max_iterations=10000 initial_temp=100 diff --git a/Q_learning/q_table.py b/Q_learning/q_table.py new file mode 100644 index 0000000..44fffd1 --- /dev/null +++ b/Q_learning/q_table.py @@ -0,0 +1,234 @@ +import random +import numpy as np +import json +import math +import yaml +# 参数设置 +STEP = 0.01 +VALUES = [round(i*STEP, 2) for i in range(101)] # 0.00~1.00 +ACTION_DELTA = [STEP, -STEP] # 增加或减少 0.01 +ACTIONS = [] # 每个动作为 (var_index, delta) +for i in range(3): + for delta in ACTION_DELTA: + ACTIONS.append((i, delta)) + +ALPHA = 0.1 # 学习率 +GAMMA = 0.9 # 折扣因子 +EPSILON = 0.2 # 探索率 +NUM_EPISODES = 100 + +def f(state): + """ + 计算切分比例的目标值 T(占位函数) + :param row_cuts: 行切分比例 + :param col_cuts: 列切分比例 + :return: 目标值 T + """ + with open('params2.yml', 'r', encoding='utf-8') as file: + params = yaml.safe_load(file) + + H = params['H'] + W = params['W'] + num_cars = params['num_cars'] + + flight_time_factor = params['flight_time_factor'] + comp_time_factor = params['comp_time_factor'] + trans_time_factor = params['trans_time_factor'] + car_time_factor = params['car_time_factor'] + bs_time_factor = params['bs_time_factor'] + + flight_energy_factor = params['flight_energy_factor'] + comp_energy_factor = params['comp_energy_factor'] + trans_energy_factor = params['trans_energy_factor'] + battery_energy_capacity = params['battery_energy_capacity'] + + col_cuts = list(state) + col_cuts.insert(0, 0) + col_cuts.append(1) + row_cuts = [0, 0.5, 1] + rectangles = [] + for i in range(len(row_cuts) - 1): + for j in range(len(col_cuts) - 1): + d = (col_cuts[j+1] - col_cuts[j]) * W * \ + (row_cuts[i+1] - row_cuts[i]) * H + rho_time_limit = (flight_time_factor - trans_time_factor) / \ + (comp_time_factor - trans_time_factor) + rho_energy_limit = (battery_energy_capacity - flight_energy_factor * d - trans_energy_factor * d) / (comp_energy_factor * d - trans_energy_factor * d) + if rho_energy_limit < 0: + return 100000 + rho = min(rho_time_limit, rho_energy_limit) + + flight_time = flight_time_factor * d + bs_time = bs_time_factor * (1 - rho) * d + + rectangles.append({ + 'flight_time': flight_time, + 'bs_time': bs_time, + 'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * H, + (col_cuts[j] + col_cuts[j+1]) / 2.0 * W) + }) + + mortorcade_time_lt = [] + for idx in range(num_cars): + car_path = car_paths[idx] + + flight_time = sum(rectangles[point]['flight_time'] + for point in car_path) + bs_time = sum(rectangles[point]['bs_time'] for point in car_path) + + car_time = 0 + for i in range(len(car_path) - 1): + first_point = car_path[i] + second_point = car_path[i + 1] + car_time += math.dist( + rectangles[first_point]['center'], rectangles[second_point]['center']) * car_time_factor + car_time += math.dist(rectangles[car_path[0]]['center'], + [H / 2, W / 2]) * car_time_factor + car_time += math.dist(rectangles[car_path[-1]]['center'], + [H / 2, W / 2]) * car_time_factor + mortorcade_time_lt.append(max(car_time + flight_time, bs_time)) + + return max(mortorcade_time_lt) + +# 环境类:定义状态转移与奖励 +class FunctionEnv: + def __init__(self, initial_state): + self.state = initial_state # 初始状态 (x1,x2,x3) + self.best_value = float('inf') # 记录最佳值 + self.no_improvement_count = 0 # 记录连续未改善的次数 + self.last_state = None # 记录上一个状态 + self.min_improvement = 0.001 # 最小改善阈值 + self.max_no_improvement = 10 # 最大允许连续未改善次数 + self.target_threshold = 10000 # 目标函数值的可接受阈值 + + def step(self, action): + # action: (var_index, delta) + var_index, delta = action + new_state = list(self.state) + new_state[var_index] = round(new_state[var_index] + delta, 2) + # 保证取值在0-1范围内 + if new_state[var_index] < 0 or new_state[var_index] > 1: + return self.state, -10000.0, True # episode结束 + # 检查约束:x1 < x2 < x3 + if not (0 < new_state[0] < new_state[1] < new_state[2] < 1): + return self.state, -10000.0, True + + next_state = tuple(new_state) + current_value = f(next_state) + + # 检查是否达到目标阈值 + if current_value < self.target_threshold: + return next_state, 12000 - current_value, True + + # 检查状态变化是否很小 + if self.last_state is not None: + state_diff = sum(abs(a - b) for a, b in zip(next_state, self.last_state)) + if state_diff < self.min_improvement: + self.no_improvement_count += 1 + else: + self.no_improvement_count = 0 + + # 检查是否有改善 + if current_value < self.best_value: + self.best_value = current_value + self.no_improvement_count = 0 + else: + self.no_improvement_count += 1 + + # 如果连续多次没有改善,结束episode + if self.no_improvement_count >= self.max_no_improvement: + return next_state, 12000 - current_value, True + + self.last_state = next_state + self.state = next_state + return next_state, 12000 - current_value, False + + def reset(self, state): + self.state = state + return self.state + +# 初始化 Q-table:使用字典表示,key 为状态 tuple,value 为 dict: action->Q值 +Q_table = {} + +def get_Q(state, action): + if state not in Q_table: + Q_table[state] = {a: 0.0 for a in ACTIONS} + return Q_table[state][action] + +def set_Q(state, action, value): + if state not in Q_table: + Q_table[state] = {a: 0.0 for a in ACTIONS} + Q_table[state][action] = value + +def choose_action(state, epsilon): + # ε-greedy 策略 + if random.random() < epsilon: + return random.choice(ACTIONS) + else: + if state not in Q_table: + Q_table[state] = {a: 0.0 for a in ACTIONS} + # 返回Q值最大的动作 + return max(Q_table[state].items(), key=lambda x: x[1])[0] + +def load_initial_solution(file_path): + """ + 从 JSON 文件加载初始解 + :param file_path: JSON 文件路径 + :return: 行切分比例、列切分比例 + """ + with open(file_path, 'r', encoding='utf-8') as file: + data = json.load(file) + row_cuts = data['row_boundaries'] + col_cuts = data['col_boundaries'] + car_paths = data['car_paths'] + return row_cuts, col_cuts, car_paths + +if __name__ == "__main__": + random.seed(42) + + # --------------------------- + # 需要修改的超参数 + # --------------------------- + solution_path = r"solutions\trav_ga_params2_parallel.json" + params_file = r"params2" + + initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution( + solution_path) + + initial_state = (0.2, 0.4, 0.7) + # Q-learning 主循环 + env = FunctionEnv(initial_state) + + for episode in range(NUM_EPISODES): + print(f"Episode {episode + 1} of {NUM_EPISODES}") + state = env.reset(initial_state) + done = False + + while not done: + # 选择动作 + action = choose_action(state, EPSILON) + # 环境执行动作 + next_state, reward, done = env.step(action) + # Q-learning 更新:Q(s,a) = Q(s,a) + α [r + γ * max_a' Q(s', a') - Q(s,a)] + if next_state not in Q_table: + Q_table[next_state] = {a: 0.0 for a in ACTIONS} + max_next_Q = max(Q_table[next_state].values()) + current_Q = get_Q(state, action) + new_Q = current_Q + ALPHA * (reward + GAMMA * max_next_Q - current_Q) + set_Q(state, action, new_Q) + state = next_state + + # 可逐步减小探索率 + EPSILON = max(0.01, EPSILON * 0.999) + + # 输出 Q-table 中最佳策略的状态和值 + best_state = None + best_value = float('inf') + for state in Q_table: + # 这里根据函数值来评价解的好坏 + state_value = f(state) + if state_value < best_value: + best_value = state_value + best_state = state + + print("找到的最优状态:", best_state, "对应函数值:", best_value) diff --git a/mtkl_sovler.py b/mtkl_sovler.py index 27a81dd..5af8701 100644 --- a/mtkl_sovler.py +++ b/mtkl_sovler.py @@ -11,12 +11,12 @@ random.seed(42) # --------------------------- # 需要修改的超参数 # --------------------------- -num_iterations = 10000 +num_iterations = 1000000 # 随机生成分区的行分段数与列分段数 # R = random.randint(0, 3) # 行分段数 # C = random.randint(0, 3) # 列分段数 R = 3 -C = 1 +C = 3 params_file = 'params2' diff --git a/solutions/finetune_params2.json b/solutions/finetune_params2.json new file mode 100644 index 0000000..59b0a55 --- /dev/null +++ b/solutions/finetune_params2.json @@ -0,0 +1,30 @@ +{ + "row_boundaries": [ + 0.0, + 0.3000000000000001, + 0.4800000000000001, + 0.77, + 1.0 + ], + "col_boundaries": [ + 0.0, + 0.5, + 1.0 + ], + "car_paths": [ + [ + 1, + 3, + 5 + ], + [ + 4, + 2, + 0 + ], + [ + 7, + 6 + ] + ] +} \ No newline at end of file diff --git a/solutions/mtkl_params2.json b/solutions/mtkl_params2.json index 2f4292c..cca8d17 100644 --- a/solutions/mtkl_params2.json +++ b/solutions/mtkl_params2.json @@ -1,30 +1,30 @@ { "row_boundaries": [ 0.0, - 0.3, - 0.4, - 0.7, + 0.5, 1.0 ], "col_boundaries": [ 0.0, + 0.2, 0.5, + 0.8, 1.0 ], "car_paths": [ [ 2, - 0 - ], - [ - 4, - 5, 3, - 1 + 7 ], [ - 6, - 7 + 1, + 5, + 6 + ], + [ + 0, + 4 ] ] } \ No newline at end of file diff --git a/solutions/trav_ga_params2.json b/solutions/trav_ga_params2.json new file mode 100644 index 0000000..cbbe99f --- /dev/null +++ b/solutions/trav_ga_params2.json @@ -0,0 +1,30 @@ +{ + "row_boundaries": [ + 0.0, + 0.1, + 0.4, + 0.7, + 1.0 + ], + "col_boundaries": [ + 0.0, + 0.5, + 1.0 + ], + "car_paths": [ + [ + 0, + 2, + 4 + ], + [ + 5, + 3, + 1 + ], + [ + 7, + 6 + ] + ] +} \ No newline at end of file diff --git a/solutions/trav_ga_params2_parallel.json b/solutions/trav_ga_params2_parallel.json index 101bee5..95c5b6c 100644 --- a/solutions/trav_ga_params2_parallel.json +++ b/solutions/trav_ga_params2_parallel.json @@ -1,7 +1,7 @@ { "row_boundaries": [ 0.0, - 0.1, + 0.2, 0.4, 0.7, 1.0 @@ -12,16 +12,16 @@ 1.0 ], "car_paths": [ + [ + 1, + 3, + 5 + ], [ 4, 2, 0 ], - [ - 5, - 3, - 1 - ], [ 7, 6 diff --git a/visualization.py b/visualization.py index 963d60a..d059d63 100644 --- a/visualization.py +++ b/visualization.py @@ -53,7 +53,7 @@ if __name__ == "__main__": # 需要修改的超参数 # --------------------------- params_file = 'params2' - solution_file = r'solutions\mtkl_params2.json' + solution_file = r'solutions\trav_finetune_params2.json' with open(params_file + '.yml', 'r', encoding='utf-8') as file: params = yaml.safe_load(file)