From d64ec83042fd487a8a39aecddabe5bba4cedc27d Mon Sep 17 00:00:00 2001 From: weixin_46229132 Date: Tue, 8 Apr 2025 15:49:22 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=AF=E8=A7=86=E5=8C=96reward?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GA/ga.py | 14 ++--- env_partion_dist.py | 20 +++---- human_action.py | 2 +- plot_dqn_training.py | 40 ++++++++++++++ solutions/dqn_params_100_100_6.json | 12 ++--- solutions/greedy_params_100_100_5.json | 73 -------------------------- visualization.py | 6 +-- 7 files changed, 67 insertions(+), 100 deletions(-) create mode 100644 plot_dqn_training.py delete mode 100644 solutions/greedy_params_100_100_5.json diff --git a/GA/ga.py b/GA/ga.py index 3b212c0..8add8dc 100644 --- a/GA/ga.py +++ b/GA/ga.py @@ -42,14 +42,14 @@ class GA(object): # print("Lens:", len(result), len(result[0])) return result - def greedy_init(self, dis_mat, num_total, num_city): + def greedy_init(self): start_index = 0 result = [] - for i in range(num_total): - rest = [x for x in range(0, num_city)] + for i in range(self.num_total): + rest = [x for x in range(0, self.num_city)] # 所有起始点都已经生成了 - if start_index >= num_city: - start_index = np.random.randint(0, num_city) + if start_index >= self.num_city: + start_index = np.random.randint(0, self.num_city) result.append(result[start_index].copy()) continue current = start_index @@ -61,8 +61,8 @@ class GA(object): tmp_choose = -1 for x in rest: # print("---", current, x, dis_mat[current][x]) - if dis_mat[current][x] < tmp_min: - tmp_min = dis_mat[current][x] + if self.dis_mat[current][x] < tmp_min: + tmp_min = self.dis_mat[current][x] tmp_choose = x if tmp_choose == -1: # 此种情况仅可能发生在剩的都是基地点 tmp_choose = rest[0] diff --git a/env_partion_dist.py b/env_partion_dist.py index c4d5ab4..f8c298b 100644 --- a/env_partion_dist.py +++ b/env_partion_dist.py @@ -19,11 +19,11 @@ class PartitionEnv(gym.Env): # 可能需要手动修改的超参数 ############################## self.params = 'params_100_100_6' - self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.6, 0.8, 1] + self.ORI_ROW_CUTS = [0, 0.28, 0.43, 0.62, 0.77, 1] self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1] - self.CUT_NUM = 9 - self.BASE_LINE = 19757.42 - self.MAX_ADJUST_STEP = 80 + self.CUT_NUM = 5 + self.BASE_LINE = 19376.06 + self.MAX_ADJUST_STEP = 50 # self.ADJUST_THRESHOLD = 0.1 # self.mTSP_STEPS = 10000 @@ -31,7 +31,7 @@ class PartitionEnv(gym.Env): self.action_space = spaces.Discrete(self.CUT_NUM*2 + 1) # 定义观察空间为8维向量 self.observation_space = spaces.Box( - low=0.0, high=1.0, shape=(self.CUT_NUM + 4,), dtype=np.float32) + low=0.0, high=1.0, shape=(len(self.ORI_ROW_CUTS)+len(self.ORI_COL_CUTS),), dtype=np.float32) self.row_cuts = self.ORI_ROW_CUTS[:] self.col_cuts = self.ORI_COL_CUTS[:] @@ -94,16 +94,16 @@ class PartitionEnv(gym.Env): cut_index, signal = (action + 1) // 2, (action + 1) % 2 if action == 0: pass - elif cut_index <= 4: + elif cut_index <= 5: if signal == 0: - self.row_cuts[cut_index] += 0.01 + self.col_cuts[cut_index] += 0.005 else: - self.row_cuts[cut_index] -= 0.01 + self.col_cuts[cut_index] -= 0.005 else: if signal == 0: - self.col_cuts[cut_index-4] += 0.01 + self.col_cuts[cut_index-4] += 0.005 else: - self.col_cuts[cut_index-4] -= 0.01 + self.col_cuts[cut_index-4] -= 0.005 # 检查row_cuts和col_cuts是否按升序排列 if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and diff --git a/human_action.py b/human_action.py index 4eab688..46a6e50 100644 --- a/human_action.py +++ b/human_action.py @@ -13,7 +13,7 @@ print('state:', state) # action_series = [1] * 30 # action_series = [[0.2], [0.4], [0.7], [0.5]] # action_series = [[-0.08], [-0.08], [0], [0]] -action_series = list(range(19)) +action_series = list(range(11)) for i in range(100): action = action_series[i] diff --git a/plot_dqn_training.py b/plot_dqn_training.py new file mode 100644 index 0000000..05c0ead --- /dev/null +++ b/plot_dqn_training.py @@ -0,0 +1,40 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np + +def plot_dqn_training_curve(csv_file): + # 读取CSV文件 + data = pd.read_csv(csv_file) + + # 检查是否包含所需列 + if 'Step' not in data.columns or 'Value' not in data.columns: + raise ValueError("CSV文件必须包含'step'和'value'列") + + # 提取数据 + steps = data['Step'] + rewards = data['Value'] + + # Plot the curve + plt.plot(steps, rewards, label='Reward', color='blue') + + # Set title and axis labels + # plt.title("DQN Training Curve", fontsize=16) + plt.xlabel("Training Steps (w)", fontsize=14) + plt.ylabel("Reward", fontsize=14) + + # Adjust x-axis ticks dynamically based on step range + step_min, step_max = steps.min(), steps.max() + tick_interval = (step_max - step_min) // 10 # Divide into 10 intervals + ticks = np.arange(step_min, step_max + tick_interval, tick_interval) + plt.xticks(ticks=ticks, labels=[f"{x//10000}w" for x in ticks]) + + # Add grid and legend + plt.grid(True, linestyle='--', alpha=0.7) + + # Show the plot + plt.show() + +if __name__ == "__main__": + # 替换为你的CSV文件路径 + csv_file = r"runs\DQN-PartEnv_S0_ 2025-04-02 20_13\DQN-PartEnv_S0_ 2025-04-02 20_13.csv" + plot_dqn_training_curve(csv_file) diff --git a/solutions/dqn_params_100_100_6.json b/solutions/dqn_params_100_100_6.json index 5211c15..736ce88 100644 --- a/solutions/dqn_params_100_100_6.json +++ b/solutions/dqn_params_100_100_6.json @@ -1,11 +1,11 @@ { - "best_time": 19643.795059416032, + "best_time": 19376.05694186515, "row_cuts": [ 0, - 0.2, - 0.4, - 0.6, - 0.78, + 0.2800000000000001, + 0.43000000000000005, + 0.62, + 0.77, 1 ], "col_cuts": [ @@ -56,5 +56,5 @@ 19, 0 ], - "timestamp": "2025-04-05 11:03:20" + "timestamp": "2025-04-06 09:10:53" } \ No newline at end of file diff --git a/solutions/greedy_params_100_100_5.json b/solutions/greedy_params_100_100_5.json deleted file mode 100644 index b066e1b..0000000 --- a/solutions/greedy_params_100_100_5.json +++ /dev/null @@ -1,73 +0,0 @@ -{ - "row_boundaries": [ - 0.0, - 0.2, - 0.4, - 0.6000000000000001, - 0.8, - 1.0 - ], - "col_boundaries": [ - 0, - 0.12533333333333332, - 0.25066666666666665, - 0.376, - 0.5013333333333333, - 0.6266666666666666, - 0.7519999999999999, - 0.8773333333333332, - 1 - ], - "car_paths": [ - [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7 - ], - [ - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15 - ], - [ - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23 - ], - [ - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31 - ], - [ - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39 - ] - ] -} \ No newline at end of file diff --git a/visualization.py b/visualization.py index e183909..859a878 100644 --- a/visualization.py +++ b/visualization.py @@ -14,9 +14,9 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r # 设置英文标题和标签 # ax.set_title("Monte Carlo", fontsize=12) - ax.set_title("Greedy", fontsize=12) + # ax.set_title("Greedy", fontsize=12) # ax.set_title("Enumeration-Genetic Algorithm", fontsize=12) - # ax.set_title("DQN fine-tuning", fontsize=12) + ax.set_title("DQN fine-tuning", fontsize=12) ax.set_xlabel("Region Width", fontsize=10) ax.set_ylabel("Region Height", fontsize=10) @@ -200,7 +200,7 @@ if __name__ == "__main__": # 需要修改的超参数 # --------------------------- params_file = 'params_100_100_6' - solution_file = r'solutions\greedy_params_100_100_6.json' + solution_file = r'solutions\finetune_params_100_100_6.json' with open(params_file + '.yml', 'r', encoding='utf-8') as file: params = yaml.safe_load(file)