diff --git a/GA/ga.py b/GA/ga.py index 336d414..0c31d22 100644 --- a/GA/ga.py +++ b/GA/ga.py @@ -1,7 +1,7 @@ import math import random -import matplotlib.pyplot as plt +# import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,7 @@ class GA(object): self.location = data self.to_process_idx = to_process_idx self.rectangles = rectangles - self.epochs = 3000 + self.epochs = 500 self.ga_choose_ratio = 0.2 self.mutate_ratio = 0.05 # fruits中存每一个个体是下标的list diff --git a/GA/main_parallel.py b/GA/main_parallel.py index f3915d9..ce83f4c 100644 --- a/GA/main_parallel.py +++ b/GA/main_parallel.py @@ -30,10 +30,10 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行 # --------------------------- # 需要修改的超参数 # --------------------------- - R = 3 - C = 3 - params_file = 'params2' - batch_size = 60 # 控制一次最多并行多少个任务 + R = 1 + C = 1 + params_file = 'params3' + batch_size = 10 # 控制一次最多并行多少个任务 with open(params_file + '.yml', 'r', encoding='utf-8') as file: params = yaml.safe_load(file) diff --git a/env_partion.py b/env_partion.py index 32f3ce4..a5ac212 100644 --- a/env_partion.py +++ b/env_partion.py @@ -4,6 +4,7 @@ import numpy as np import yaml import math from mTSP_solver import mTSP +from GA.ga import GA class PartitionEnv(gym.Env): @@ -17,11 +18,11 @@ class PartitionEnv(gym.Env): ############################## # 可能需要手动修改的超参数 ############################## - self.params = 'params2' - self.CUT_NUM = 4 - self.ROW_CUT_LIMIT = 3 + self.params = 'params3' + self.CUT_NUM = 2 + self.ROW_CUT_LIMIT = 1 self.COL_CUT_LIMIT = 1 - self.BASE_LINE = 12000 + self.BASE_LINE = 5000 self.mTSP_STEPS = 10000 # 车队参数设置 @@ -138,22 +139,41 @@ class PartitionEnv(gym.Env): state = self.partition_values # 继续进行路径规划 + # 使用q_learning解多旅行商 # cities: [[x1, x2, x3...], [y1, y2, y3...]] 城市坐标 - rec_center_lt = [rec_info['center'] - for rec_info in self.rectangles] - cities = np.column_stack(rec_center_lt) - cities = np.column_stack((self.center, cities)) + # rec_center_lt = [rec_info['center'] + # for rec_info in self.rectangles] + # cities = np.column_stack(rec_center_lt) + # cities = np.column_stack((self.center, cities)) - center_idx = [] + # center_idx = [] + # for i in range(self.num_cars - 1): + # cities = np.column_stack((cities, self.center)) + # center_idx.append(cities.shape[1] - 1) + + # tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars, + # center_idx=center_idx, rectangles=self.rectangles) + + # best_time, best_path = tsp.train(self.mTSP_STEPS) + + # 使用遗传算法解多旅行商 + cities = [self.center] + for rec in self.rectangles: + cities.append(rec['center']) + cities = np.array(cities) + + center_idx = [0] for i in range(self.num_cars - 1): - cities = np.column_stack((cities, self.center)) - center_idx.append(cities.shape[1] - 1) + cities = np.row_stack((cities, self.center)) + center_idx.append(cities.shape[0] - 1) - tsp = mTSP(params=self.params, num_cities=cities.shape[1], cities=cities, num_cars=self.num_cars, - center_idx=center_idx, rectangles=self.rectangles) - best_time, best_path = tsp.train(self.mTSP_STEPS) - print(best_time) - print(best_path) + ga = GA(num_drones=self.num_cars, num_city=cities.shape[0], num_total=20, + data=cities, to_process_idx=center_idx, rectangles=self.rectangles) + + best_path, best_time = ga.run() + + # print(best_time) + # print(best_path) reward += self.BASE_LINE - best_time diff --git a/human_action.py b/human_action.py index b356699..9e3fe34 100644 --- a/human_action.py +++ b/human_action.py @@ -11,6 +11,7 @@ print('state:', state) # action_series = [[0.67], [0], [0], [0], [0.7]] # action_series = [0, 0, 3, 0, 10] action_series = [[0.2], [0.4], [0.7], [0.5]] +# action_series = [[0.5], [0.5]] for i in range(100): action = action_series[i] diff --git a/mTSP_solver.py b/mTSP_solver.py index 6e7e12b..8ed92d6 100644 --- a/mTSP_solver.py +++ b/mTSP_solver.py @@ -15,9 +15,9 @@ class mTSP(object): num_cars=2, center_idx=[0], rectangles=None, - alpha=2, - beta=1, - learning_rate=0.001, + alpha=1, + beta=4, + learning_rate=0.01, eps=0.1, ): ''' diff --git a/solutions/trav_ga_params2_parallel.json b/solutions/trav_ga_params2_parallel.json index 95c5b6c..7ffa6a9 100644 --- a/solutions/trav_ga_params2_parallel.json +++ b/solutions/trav_ga_params2_parallel.json @@ -1,7 +1,7 @@ { "row_boundaries": [ 0.0, - 0.2, + 0.1, 0.4, 0.7, 1.0 @@ -18,13 +18,13 @@ 5 ], [ - 4, + 0, 2, - 0 + 4 ], [ - 7, - 6 + 6, + 7 ] ] } \ No newline at end of file diff --git a/solutions/trav_ga_params3_parallel.json b/solutions/trav_ga_params3_parallel.json index 929d6da..94bc288 100644 --- a/solutions/trav_ga_params3_parallel.json +++ b/solutions/trav_ga_params3_parallel.json @@ -11,8 +11,8 @@ ], "car_paths": [ [ - 2, - 0 + 0, + 2 ], [ 3,