diff --git a/Duel_Double_DQN/main.py b/Duel_Double_DQN/main.py
index 634a00b..e5b2417 100644
--- a/Duel_Double_DQN/main.py
+++ b/Duel_Double_DQN/main.py
@@ -142,7 +142,7 @@ def main():
         if total_steps % 1000 == 0:
             agent.exp_noise *= opt.noise_decay
         if total_steps % opt.eval_interval == 0:
-            score = evaluate_policy(eval_env, agent, turns=3)
+            score = evaluate_policy(eval_env, agent, turns=1)
             if opt.write:
                 writer.add_scalar(
                     'ep_r', score, global_step=total_steps)
diff --git a/Duel_Double_DQN/utils.py b/Duel_Double_DQN/utils.py
index 5d89f0d..b37e84c 100644
--- a/Duel_Double_DQN/utils.py
+++ b/Duel_Double_DQN/utils.py
@@ -1,31 +1,88 @@
-def evaluate_policy(env, agent, turns = 3):
+import json
+from datetime import datetime
+import copy
+
+
+def evaluate_policy(env, agent, turns=3):
+    """
+    Evaluate the policy.
+    Args:
+        env: environment object
+        agent: agent object
+        turns: number of evaluation episodes
+    Returns:
+        int: average score
+    """
     total_scores = 0
-    for j in range(turns):
-        s = env.reset()
-        done = False
-        action_series = []
-        while not done:
-            # Take deterministic actions at test time
-            a = agent.select_action(s, deterministic=True)
-            s_next, r, dw, tr, info = env.step(a)
-            done = (dw or tr)
-            action_series.append(a)
-            total_scores += r
-            s = s_next
-        print('action series: ', action_series)
-        print('state: ', s)
+
+    # for j in range(turns):
+    s = env.reset()
+    done = False
+    eval_info = {'action_series': [],
+                 # 'state_series': [],
+                 'reward_series': []}
+    info_lt = []
+
+    while not done:
+        a = agent.select_action(s, deterministic=True)
+        s_next, r, dw, tr, info = env.step(a)
+        done = (dw or tr)
+
+        eval_info['action_series'].append(a)
+        eval_info['reward_series'].append(r)
+        info_lt.append(copy.deepcopy(info))
+
+        total_scores += r
+        s = s_next
+
+    print(eval_info)
+    save_best_solution(info_lt)
+
     return int(total_scores/turns)
 
 
-#You can just ignore this funciton. Is not related to the RL.
+def save_best_solution(info_lt):
+    # Find the best solution of this episode
+    best_info = min(info_lt, key=lambda x: x['best_time'])
+
+    # Load the previously saved best solution, if any
+    try:
+        with open('solutions/dqn_params_50_50_3.json', 'r') as f:
+            saved_solution = json.load(f)
+            saved_time = saved_solution['best_time']
+    except FileNotFoundError:
+        saved_time = float('inf')
+
+    # If the new solution is better, update the JSON file
+    if best_info['best_time'] < saved_time:
+        best_solution = {
+            'best_time': best_info['best_time'],
+            'row_cuts': best_info['row_cuts'],
+            'col_cuts': best_info['col_cuts'],
+            'best_path': best_info['best_path'],
+            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        }
+
+        with open('solutions/dqn_params_50_50_3.json', 'w') as f:
+            json.dump(best_solution, f, indent=4)
+
+        print(f"Found a new best solution! Time: {best_info['best_time']}")
+
+
+def compare_lists(list1, list2):
+    return len(list1) == len(list2) and all(a == b for a, b in zip(list1, list2))
+
+# You can just ignore this function. It is not related to the RL.
+
+
 def str2bool(v):
     '''transfer str to bool for argparse'''
     if isinstance(v, bool):
         return v
-    if v.lower() in ('yes', 'True','true','TRUE', 't', 'y', '1'):
+    if v.lower() in ('yes', 'True', 'true', 'TRUE', 't', 'y', '1'):
         return True
-    elif v.lower() in ('no', 'False','false','FALSE', 'f', 'n', '0'):
+    elif v.lower() in ('no', 'False', 'false', 'FALSE', 'f', 'n', '0'):
         return False
     else:
         print('Wrong Input.')
-        raise
\ No newline at end of file
+        raise
diff --git a/GA/ga.py b/GA/ga.py
index 9bc6fd6..9cd036a 100644
--- a/GA/ga.py
+++ b/GA/ga.py
@@ -16,7 +16,7 @@ class GA(object):
         self.location = data
         self.to_process_idx = to_process_idx
         self.rectangles = rectangles
-        self.epochs = 1000
+        self.epochs = 1500
         self.ga_choose_ratio = 0.2
         self.mutate_ratio = 0.05
         # fruits stores each individual as a list of indices
diff --git a/env_partion_dist.py b/env_partion_dist.py
index bb6bf56..6aa70a4 100644
--- a/env_partion_dist.py
+++ b/env_partion_dist.py
@@ -18,11 +18,11 @@ class PartitionEnv(gym.Env):
         ##############################
         # Hyperparameters that may need manual tuning
         ##############################
-        self.params = 'params2'
+        self.params = 'params_50_50_3'
         self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
         self.ORI_COL_CUTS = [0, 0.5, 1]
         self.CUT_NUM = 4
-        self.BASE_LINE = 9100
+        self.BASE_LINE = 9051.16
         self.MAX_ADJUST_STEP = 50
         self.ADJUST_THRESHOLD = 0.1
         # self.mTSP_STEPS = 10000
@@ -115,11 +115,13 @@ class PartitionEnv(gym.Env):
             reward = self.calc_reward(best_time)
             self.adjust_step += 1
             state = np.array(self.row_cuts + self.col_cuts)
+            info = {'row_cuts': self.row_cuts, 'col_cuts': self.col_cuts,
+                    'best_path': self.best_path, 'best_time': best_time}
 
             if self.adjust_step < self.MAX_ADJUST_STEP:
-                return state, reward, False, False, {}
+                return state, reward, False, False, info
             else:
-                return state, reward, True, False, {}
+                return state, reward, True, False, info
 
     def if_valid_partition(self):
         rectangles = []
@@ -221,7 +223,11 @@ class PartitionEnv(gym.Env):
 
     def calc_reward(self, best_time):
         """
-        Compute the reward
+        Compute the reward:
+        1. If the time is below the baseline, give a positive reward
+        2. If the time is above the baseline, give a negative reward
+        3. Keep the normalization and the discount factor
+
         Args:
             best_time (float): time of the current path
         Returns:
@@ -229,14 +235,15 @@ class PartitionEnv(gym.Env):
         """
         time_diff = self.BASE_LINE - best_time
 
-        # Normalize the time difference
-        normalized_diff = 1 / (1 + np.exp(-time_diff/20))
+        # Normalize with tanh so that normalized_diff = 0 when time_diff = 0
+        # tanh is already close to 1 at an argument of 2; the largest time_diff is about 400
+        normalized_diff = np.tanh(time_diff / 200)  # 200 is the scaling factor and can be tuned
 
-        # Compute the step weight
+        # Compute the step weight (discount factor)
         step_weight = 1 / (1 + np.exp(-self.adjust_step/10))
 
-        # Compute the final reward (with a scaling factor)
-        reward = normalized_diff * step_weight * 10  # 10 is the scaling factor
+        # Compute the final reward
+        reward = normalized_diff * step_weight
 
         return reward
 
diff --git a/mtkl_sovler2.py b/mtkl_sovler2.py
index ab39f08..c25a00a 100644
--- a/mtkl_sovler2.py
+++ b/mtkl_sovler2.py
@@ -12,7 +12,7 @@ random.seed(42)
 # ---------------------------
 # Hyperparameters that need tuning
 # ---------------------------
-num_iterations = 100000000
+num_iterations = 3000000000
 # Randomly generate the numbers of row and column segments for the partition
 R = random.randint(0, 3)  # number of row segments
 C = random.randint(0, 3)  # number of column segments
@@ -47,13 +47,15 @@ best_solution = None
 
 for iteration in tqdm(range(num_iterations), desc="Monte Carlo simulation progress"):
     # Sample the cut values directly
-    horiz = [random.random() for _ in range(R)]
+    # horiz = [random.random() for _ in range(R)]
+    horiz = [random.randint(1, 999)/1000 for _ in range(R)]
     horiz = sorted(set(horiz))
     horiz = horiz if horiz else []
     row_boundaries = [0] + horiz + [1]
     row_boundaries = [boundary * H for boundary in row_boundaries]
 
-    vert = [random.random() for _ in range(C)]
+    # vert = [random.random() for _ in range(C)]
+    vert = [random.randint(1, 999)/1000 for _ in range(C)]
     vert = sorted(set(vert))
     vert = vert if vert else []
     col_boundaries = [0] + vert + [1]
@@ -151,8 +153,6 @@ for iteration in tqdm(range(num_iterations), desc="Monte Carlo simulation progress"):
 
     T_max = max(T_k_list)  # The overall objective T is the largest T_k among the systems
 
-    # TODO: the total energy consumption of the systems is not constrained
-
     if T_max < best_T:
         best_T = T_max
         best_solution = {
@@ -168,6 +168,7 @@ for iteration in tqdm(range(num_iterations), desc="Monte Carlo simulation progress"):
             'flight_time': total_flight_time,
             'bs_time': total_bs_time
         }
+        print(iteration)
 
 # ---------------------------
 # Output the best solution
diff --git a/solutions/dqn_params_50_50_3.json b/solutions/dqn_params_50_50_3.json
new file mode 100644
index 0000000..af7e7ba
--- /dev/null
+++ b/solutions/dqn_params_50_50_3.json
@@ -0,0 +1,29 @@
+{
+    "best_time": 9051.162633521315,
+    "row_cuts": [
+        0,
+        0.21000000000000002,
+        0.4,
+        0.7,
+        1
+    ],
+    "col_cuts": [
+        0,
+        0.5,
+        1
+    ],
+    "best_path": [
+        7,
+        8,
+        0,
+        6,
+        2,
+        4,
+        10,
+        9,
+        5,
+        3,
+        1
+    ],
+    "timestamp": "2025-04-01 17:43:22"
+}
\ No newline at end of file
diff --git a/visualization.py b/visualization.py
index 3165558..1bb8edf 100644
--- a/visualization.py
+++ b/visualization.py
@@ -14,8 +14,8 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r
 
     # Set the English title and labels
     # ax.set_title("Monte Carlo", fontsize=12)
-    ax.set_title("Greedy", fontsize=12)
-    # ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
+    # ax.set_title("Greedy", fontsize=12)
+    ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
     # ax.set_title("DQN fine-tuning", fontsize=12)
 
     ax.set_xlabel("Region Width", fontsize=10)
@@ -200,7 +200,7 @@ if __name__ == "__main__":
     # ---------------------------
     # Hyperparameters that need tuning
     # ---------------------------
     params_file = 'params_50_50_3'
-    solution_file = r'solutions\greedy_params_50_50_3.json'
+    solution_file = r'solutions\trav_ga_params_50_50_3_parallel.json'
 
     with open(params_file + '.yml', 'r', encoding='utf-8') as file:
         params = yaml.safe_load(file)
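
Note (illustrative, not part of the patch): the switch in PartitionEnv.calc_reward from the sigmoid 1/(1+exp(-time_diff/20)) to np.tanh(time_diff/200) makes the reward zero-centered: a best_time equal to BASE_LINE now yields reward 0 instead of 0.5, and beating or missing the baseline gives symmetric positive or negative shaping. A minimal standalone sketch of the patched reward, assuming only numpy; the sample times below are made up:

    import numpy as np

    BASE_LINE = 9051.16  # baseline from env_partion_dist.py

    def calc_reward(best_time, adjust_step):
        time_diff = BASE_LINE - best_time
        # tanh(0) = 0, tanh(+-1) ~ +-0.76, tanh(+-2) ~ +-0.96 (time_diff ~ +-400)
        normalized_diff = np.tanh(time_diff / 200)
        # later adjustment steps are weighted more heavily (sigmoid ramp over the episode)
        step_weight = 1 / (1 + np.exp(-adjust_step / 10))
        return normalized_diff * step_weight

    for t in (8700.0, 9051.16, 9400.0):  # better than, equal to, worse than baseline
        print(t, round(calc_reward(t, adjust_step=25), 3))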
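Note (illustrative, not part of the patch): the contract between PartitionEnv.step and save_best_solution is the new info dict with the keys row_cuts, col_cuts, best_path and best_time; evaluate_policy deep-copies one per step so the cheapest can be picked afterwards. Also note that evaluate_policy still divides by turns although the loop over turns is commented out, so it effectively evaluates a single episode; main.py compensates by calling it with turns=1. A small sketch of that flow, with stand-in values borrowed from the committed JSON:

    import copy

    info_lt = []
    for best_time in (9120.5, 9051.162633521315, 9088.0):  # fake three-step rollout
        info = {'row_cuts': [0, 0.21, 0.4, 0.7, 1],
                'col_cuts': [0, 0.5, 1],
                'best_path': [7, 8, 0, 6, 2, 4, 10, 9, 5, 3, 1],
                'best_time': best_time}
        info_lt.append(copy.deepcopy(info))  # deep copy, as in evaluate_policy

    # save_best_solution keeps the step with the smallest best_time
    best_info = min(info_lt, key=lambda x: x['best_time'])
    print(best_info['best_time'])  # 9051.162633521315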
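Note (illustrative, not part of the patch): replacing random.random() with random.randint(1, 999)/1000 in mtkl_sovler2.py snaps every sampled cut onto a 0.001 grid and keeps it strictly inside (0, 1), so a cut can never coincide with the fixed 0/1 endpoints; sorted(set(...)) still removes grid collisions between cuts. A tiny sketch of the sampling step:

    import random

    random.seed(42)
    R = 3                                        # number of interior row cuts
    horiz = [random.randint(1, 999) / 1000 for _ in range(R)]
    horiz = sorted(set(horiz))                   # dedupe collisions, ascending order
    row_boundaries = [0] + horiz + [1]           # endpoints appear exactly once
    print(row_boundaries)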