diff --git a/GA/ga.py b/GA/ga.py index fc796e3..3b212c0 100644 --- a/GA/ga.py +++ b/GA/ga.py @@ -16,7 +16,7 @@ class GA(object): self.location = data self.to_process_idx = to_process_idx self.rectangles = rectangles - self.epochs = 300 + self.epochs = 1000 self.ga_choose_ratio = 0.2 self.mutate_ratio = 0.05 # fruits中存每一个个体是下标的list @@ -314,7 +314,7 @@ class GA(object): early_stop_cnt = 0 else: early_stop_cnt += 1 - if early_stop_cnt == 50: # 若连续50次没有性能提升,则早停 + if early_stop_cnt == 100: # 若连续50次没有性能提升,则早停 break self.best_record.append(1.0 / best_score) best_length = 1.0 / best_score diff --git a/GA/main.py b/GA/main.py index 189bbb4..80d2fdf 100644 --- a/GA/main.py +++ b/GA/main.py @@ -3,7 +3,7 @@ import math import yaml import numpy as np from utils import if_valid_partition, GA_solver -from itertools import product +from itertools import product, combinations import json from tqdm import tqdm @@ -18,9 +18,9 @@ best_col_boundaries = None # --------------------------- # 需要修改的超参数 # --------------------------- -R = 3 -C = 3 -params_file = 'params_50_50_3' +R = 10 +C = 10 +params_file = 'params_100_100_5' with open(params_file + '.yml', 'r', encoding='utf-8') as file: @@ -41,20 +41,29 @@ comp_energy_factor = params['comp_energy_factor'] trans_energy_factor = params['trans_energy_factor'] battery_energy_capacity = params['battery_energy_capacity'] -# 定义数字列表 -numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] +# # 定义数字列表 +# numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] -# 生成所有的排列情况(取三次,每次都可以从10个数中选) -row_product = list(product(numbers, repeat=R)) -# 对每种情况从小到大排序,并剔除重复的情况 -row_cuts_set = set( - tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) -row_cuts_set = sorted(row_cuts_set) +# # 生成所有的排列情况(取三次,每次都可以从10个数中选) +# row_product = list(product(numbers, repeat=R)) +# # 对每种情况从小到大排序,并剔除重复的情况 +# row_cuts_set = set( +# tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) +# row_cuts_set = sorted(row_cuts_set) -col_product = list(product(numbers, repeat=C)) -col_cuts_set = set( - tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) -col_cuts_set = sorted(col_cuts_set) +# col_product = list(product(numbers, repeat=C)) +# col_cuts_set = set( +# tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) +# col_cuts_set = sorted(col_cuts_set) + +numbers = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] +row_cuts_set = [] +for i in range(R): + row_cuts_set += list(combinations(numbers, i+1)) + +col_cuts_set = [] +for i in range(C): + col_cuts_set += list(combinations(numbers, i+1)) total_iterations = len(row_cuts_set) * len(col_cuts_set) with tqdm(total=total_iterations, desc="Processing") as pbar: diff --git a/GA/main_parallel.py b/GA/main_parallel.py index 0548a0c..026f0cf 100644 --- a/GA/main_parallel.py +++ b/GA/main_parallel.py @@ -3,7 +3,7 @@ import math import yaml import numpy as np from utils import if_valid_partition, GA_solver -from itertools import product +from itertools import product, combinations import json from tqdm import tqdm from concurrent.futures import ProcessPoolExecutor, as_completed @@ -30,9 +30,9 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行 # --------------------------- # 需要修改的超参数 # --------------------------- - R = 3 - C = 1 - params_file = 'params_50_50_3' + R = 10 + C = 10 + params_file = 'params_100_100_6' batch_size = 60 # 控制一次最多并行多少个任务 with open(params_file + '.yml', 'r', encoding='utf-8') as file: @@ -54,19 +54,27 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行 battery_energy_capacity = params['battery_energy_capacity'] # 定义数字列表 - numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + numbers = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] # 生成所有的排列情况(取三次,每次都可以从10个数中选) - row_product = list(product(numbers, repeat=R)) - # 对每种情况从小到大排序,并剔除重复的情况 - row_cuts_set = set( - tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) - row_cuts_set = sorted(row_cuts_set) + # row_product = list(product(numbers, repeat=R)) + # # 对每种情况从小到大排序,并剔除重复的情况 + # row_cuts_set = set( + # tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) + # row_cuts_set = sorted(row_cuts_set) - col_product = list(product(numbers, repeat=C)) - col_cuts_set = set( - tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) - col_cuts_set = sorted(col_cuts_set) + # col_product = list(product(numbers, repeat=C)) + # col_cuts_set = set( + # tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) + # col_cuts_set = sorted(col_cuts_set) + + row_cuts_set = [] + for i in range(R): + row_cuts_set += list(combinations(numbers, i+1)) + + col_cuts_set = [] + for i in range(C): + col_cuts_set += list(combinations(numbers, i+1)) best_T = float('inf') best_solution = None diff --git a/greedy_solver.py b/greedy_solver.py index 8a12b63..fd53ccc 100644 --- a/greedy_solver.py +++ b/greedy_solver.py @@ -81,7 +81,7 @@ def main(): # --------------------------- # 需要修改的超参数 # --------------------------- - params_file = 'params_50_50_3' + params_file = 'params_100_100_6' # 读取参数 with open(params_file + '.yml', 'r', encoding='utf-8') as file: diff --git a/params_100_100_6.yml b/params_100_100_6.yml new file mode 100644 index 0000000..48b07dd --- /dev/null +++ b/params_100_100_6.yml @@ -0,0 +1,16 @@ +H : 100 # 区域高度,网格点之间的距离为25m(单位距离) +W : 100 # 区域宽度 +num_cars : 6 # 系统数量(车-巢-机系统个数) + +# 时间系数(单位:秒,每个网格一张照片) +flight_time_factor : 3 # 每张照片对应的飞行时间,无人机飞行速度为9.5m/s,拍摄照片的时间间隔为3s +comp_time_factor : 5 # 无人机上每张照片计算时间,5s +trans_time_factor : 0.3 # 每张照片传输时间,0.3s +car_time_factor : 100 # TODO 汽车每单位距离的移动时间,2s,加了一个放大因子50 +bs_time_factor : 5 # 机巢上每张照片计算时间 + +# 其他参数 +flight_energy_factor : 0.05 # 单位:分钟/张 +comp_energy_factor : 0.05 # TODO 计算能耗需要重新估计 +trans_energy_factor : 0.0025 +battery_energy_capacity : 20 # 无人机只进行飞行,续航为30分钟 \ No newline at end of file diff --git a/solutions/greedy_params_100_100_5.json b/solutions/greedy_params_100_100_5.json new file mode 100644 index 0000000..b066e1b --- /dev/null +++ b/solutions/greedy_params_100_100_5.json @@ -0,0 +1,73 @@ +{ + "row_boundaries": [ + 0.0, + 0.2, + 0.4, + 0.6000000000000001, + 0.8, + 1.0 + ], + "col_boundaries": [ + 0, + 0.12533333333333332, + 0.25066666666666665, + 0.376, + 0.5013333333333333, + 0.6266666666666666, + 0.7519999999999999, + 0.8773333333333332, + 1 + ], + "car_paths": [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7 + ], + [ + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15 + ], + [ + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23 + ], + [ + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31 + ], + [ + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39 + ] + ] +} \ No newline at end of file diff --git a/solutions/greedy_params_100_100_6.json b/solutions/greedy_params_100_100_6.json new file mode 100644 index 0000000..af22109 --- /dev/null +++ b/solutions/greedy_params_100_100_6.json @@ -0,0 +1,77 @@ +{ + "row_boundaries": [ + 0.0, + 0.16666666666666666, + 0.3333333333333333, + 0.5, + 0.6666666666666666, + 0.8333333333333333, + 1.0 + ], + "col_boundaries": [ + 0, + 0.1504, + 0.3008, + 0.45120000000000005, + 0.6016, + 0.752, + 0.9024, + 1 + ], + "car_paths": [ + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ], + [ + 7, + 8, + 9, + 10, + 11, + 12, + 13 + ], + [ + 14, + 15, + 16, + 17, + 18, + 19, + 20 + ], + [ + 21, + 22, + 23, + 24, + 25, + 26, + 27 + ], + [ + 28, + 29, + 30, + 31, + 32, + 33, + 34 + ], + [ + 35, + 36, + 37, + 38, + 39, + 40, + 41 + ] + ] +} \ No newline at end of file diff --git a/solutions/trav_ga_params_100_100_5_parallel.json b/solutions/trav_ga_params_100_100_5_parallel.json new file mode 100644 index 0000000..3d68b01 --- /dev/null +++ b/solutions/trav_ga_params_100_100_5_parallel.json @@ -0,0 +1,74 @@ +{ + "row_boundaries": [ + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0 + ], + "col_boundaries": [ + 0.0, + 0.2, + 0.4, + 0.7, + 1.0 + ], + "car_paths": [ + [ + 17, + 5, + 4, + 0, + 1, + 2, + 6, + 10, + 14 + ], + [ + 18, + 13, + 9, + 8, + 12, + 16, + 20, + 32, + 21 + ], + [ + 22, + 26, + 30, + 34, + 39, + 35, + 31, + 27, + 19 + ], + [ + 25, + 24, + 28, + 36, + 33, + 37, + 38, + 29 + ], + [ + 15, + 23, + 7, + 11, + 3 + ] + ] +} \ No newline at end of file diff --git a/GA/GA_MTSP.py b/test_new_GA/GA_MTSP.bak similarity index 100% rename from GA/GA_MTSP.py rename to test_new_GA/GA_MTSP.bak diff --git a/test_new_GA/GA_MTSP.py b/test_new_GA/GA_MTSP.py new file mode 100644 index 0000000..e8e62be --- /dev/null +++ b/test_new_GA/GA_MTSP.py @@ -0,0 +1,278 @@ +import numpy as np +import random +import math +import matplotlib.pyplot as plt +import time +# 设置随机种子 +np.random.seed(42) +random.seed(42) + + +class MTSP_GA: + def __init__(self, cities, params, rectangles, population_size=200, max_iterations=1500): + """ + 初始化遗传算法求解器 + Args: + cities: 城市坐标数组,第一个城市为起始点 + vehicle_num: 车辆数量 + population_size: 种群大小 + max_iterations: 最大迭代次数 + """ + H = params['H'] + W = params['W'] + k = params['num_cars'] + + flight_time_factor = params['flight_time_factor'] + comp_time_factor = params['comp_time_factor'] + trans_time_factor = params['trans_time_factor'] + car_time_factor = params['car_time_factor'] + bs_time_factor = params['bs_time_factor'] + + flight_energy_factor = params['flight_energy_factor'] + comp_energy_factor = params['comp_energy_factor'] + trans_energy_factor = params['trans_energy_factor'] + battery_energy_capacity = params['battery_energy_capacity'] + + self.cities = np.array(cities) + self.city_count = len(cities) + self.vehicle_num = k + self.origin = 0 # 起始点 + self.rectangles = rectangles + self.car_time_factor = car_time_factor + self.H = H + self.W = W + + # GA参数 + self.population_size = population_size + self.max_iterations = max_iterations + self.retain_rate = 0.3 # 强者存活率 + self.random_rate = 0.5 # 弱者存活概率 + self.mutation_rate = 0.5 # 变异率 + + # 计算距离矩阵 + self.distance_matrix = self._compute_distance_matrix() + + # 记录收敛过程 + self.distance_history = [] + self.best_path_history = [] + + def _compute_distance_matrix(self): + """计算城市间距离矩阵""" + distance = np.zeros((self.city_count, self.city_count)) + for i in range(self.city_count): + for j in range(self.city_count): + distance[i][j] = math.sqrt( + (self.cities[i][0] - self.cities[j][0]) ** 2 + + (self.cities[i][1] - self.cities[j][1]) ** 2 + ) + return distance + + def _create_individual(self): + """生成初始个体""" + index = [i for i in range(self.city_count)] + index.remove(self.origin) + a = int(np.floor(len(index)/self.vehicle_num)) + X = [] + for i in range(self.vehicle_num): + if i < self.vehicle_num-1: + x = index[a*i:a*(i+1)] + else: + x = index[a*i:] + X.append(x) + return X + + def _get_total_distance(self, X): + """计算路径总距离""" + # 根据car_paths计算时间 + for car_idx in range(self.vehicle_num): + car_path = X[car_idx] + flight_time = sum(self.rectangles[point - 1]['flight_time'] + for point in car_path) + bs_time = sum(self.rectangles[point - 1] + ['bs_time'] for point in car_path) + + # 计算车的移动时间,首先在轨迹的首尾添加上大区域中心 + car_time = 0 + car_time += self.distance_matrix[self.origin][car_path[0]] * self.car_time_factor + car_time += self.distance_matrix[self.origin][car_path[-1]] * self.car_time_factor + for i in range(len(car_path) - 1): + first_point = car_path[i] + second_point = car_path[i + 1] + car_time += self.distance_matrix[first_point][second_point] * self.car_time_factor + + system_time = max(flight_time + car_time, bs_time) + + return system_time + + def _selection(self, population): + """选择操作""" + graded = [[self._get_total_distance(x), x] for x in population] + graded = [x[1] for x in sorted(graded)] + retain_length = int(len(graded) * self.retain_rate) + parents = graded[:retain_length] + + for chromosome in graded[retain_length:]: + if random.random() < self.random_rate: + parents.append(chromosome) + return parents + + def _crossover(self, parents): + """交叉操作""" + target_count = self.population_size - len(parents) + children = [] + while len(children) < target_count: + male_index = random.randint(0, len(parents) - 1) + female_index = random.randint(0, len(parents) - 1) + if male_index != female_index: + male = parents[male_index] + female = parents[female_index] + + gene1 = [] + gene2 = [] + for i in range(len(male)): + gene1 += male[i] + gene2 += female[i] + + left = random.randint(0, len(gene1)//2) + right = random.randint(left + 1, len(gene1)) + cut = gene1[left:right] + copy = gene2.copy() + for j in cut: + copy.remove(j) + + child = copy + cut + a = int(np.floor(len(child)/self.vehicle_num)) + child_c = [] + for i in range(self.vehicle_num): + if i < self.vehicle_num - 1: + x = child[a * i:a * (i + 1)] + else: + x = child[a * i:] + child_c.append(x) + children.append(child_c) + return children + + def _mutation(self, children): + """变异操作""" + for i in range(len(children)): + if random.random() < self.mutation_rate: + child = children[i] + for j in range(int(np.floor(len(child)/2))): + a = 2*j + u = random.randint(1, len(child[a]) - 1) + w = random.randint(1, len(child[a+1]) - 1) + child_1 = child[a][:u].copy() + child_2 = child[a][u:].copy() + child_3 = child[a+1][:w].copy() + child_4 = child[a+1][w:].copy() + child_a = child_1+child_3 + child_b = child_2+child_4 + child[a] = child_a + child[a+1] = child_b + children[i] = child.copy() + return children + + def _get_best_solution(self, population): + """获取最优解""" + graded = [[self._get_total_distance(x), x] for x in population] + graded = sorted(graded, key=lambda x: x[0]) + return graded[0][0], graded[0][1] + + def solve(self): + """ + 求解MTSP,加入早停机制 + 当连续50轮没有更好的解时停止迭代 + """ + # 初始化种群 + population = [self._create_individual() + for _ in range(self.population_size)] + + # 初始化早停相关变量 + best_distance = float('inf') + early_stop_counter = 0 + early_stop_threshold = 100 + + # 迭代优化 + for i in range(self.max_iterations): + parents = self._selection(population) + children = self._crossover(parents) + children = self._mutation(children) + population = parents + children + + # 记录当前最优解 + current_distance, current_path = self._get_best_solution( + population) + self.distance_history.append(current_distance) + self.best_path_history.append(current_path) + + # 早停判断 + if current_distance < best_distance: + best_distance = current_distance + best_path = current_path + # early_stop_counter = 0 # 重置计数器 + # else: + # early_stop_counter += 1 + + # # 如果连续50轮没有更好的解,则停止迭代 + # if early_stop_counter >= early_stop_threshold: + # # print( + # # f"Early stopping at iteration {i} due to no improvement in {early_stop_threshold} iterations") + # break + + # 返回最优解 + return best_distance, best_path + + def plot_convergence(self): + """绘制收敛曲线""" + plt.plot(range(len(self.distance_history)), self.distance_history) + plt.xlabel('Iteration') + plt.ylabel('Total Distance') + plt.title('Convergence Curve') + plt.show() + + +def main(): + # 城市坐标 + cities = np.array([ + (456, 320), # 起点(基地) + (228, 0), + (912, 0), + (0, 80), + (114, 80), + (570, 160), + (798, 160), + (342, 240), + (684, 240), + (570, 400), + (912, 400), + (114, 480), + (228, 480), + (342, 560), + (684, 560), + (0, 640), + (798, 640) + ]) + + # 创建求解器实例 + solver = MTSP_GA( + cities=cities, + vehicle_num=4, + population_size=200, + max_iterations=1500 + ) + + # 求解 + start_time = time.time() + best_distance, best_path = solver.solve() + end_time = time.time() + + # 输出结果 + print(f"最优总距离: {best_distance:.2f}") + print("最优路径方案:") + for i, path in enumerate(best_path): + print(f"车辆{i+1}的路径: {path}") + print(f"求解时间: {end_time - start_time:.2f}秒") + + +if __name__ == "__main__": + main() diff --git a/test_new_GA/main_parallel.py b/test_new_GA/main_parallel.py new file mode 100644 index 0000000..e34da02 --- /dev/null +++ b/test_new_GA/main_parallel.py @@ -0,0 +1,132 @@ +import random +import math +import yaml +import numpy as np +from utils import if_valid_partition +from itertools import product +import json +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed +from GA_MTSP import MTSP_GA + + +def process_partition(row_cuts, col_cuts, params): + row_boundaries = [0.0] + list(row_cuts) + [1.0] + col_boundaries = [0.0] + list(col_cuts) + [1.0] + rectangles = if_valid_partition(row_boundaries, col_boundaries, params) + + if not rectangles: + return None # 过滤无效划分 + + cities = [(params['H'] / 2.0, params['W'] / 2.0)] + for rec in rectangles: + cities.append(rec['center']) + sovler = MTSP_GA( + cities=cities, params=params, rectangles=rectangles, population_size=200, max_iterations=2000) + current_time, current_solution = sovler.solve() + + return (current_solution, current_time, row_boundaries, col_boundaries, rectangles) + + +if __name__ == "__main__": # 重要:在 Windows 上必须加这一行 + np.random.seed(42) + random.seed(42) + + # --------------------------- + # 需要修改的超参数 + # --------------------------- + R = 3 + C = 1 + params_file = 'params_50_50_3' + batch_size = 60 # 控制一次最多并行多少个任务 + + with open(params_file + '.yml', 'r', encoding='utf-8') as file: + params = yaml.safe_load(file) + + H = params['H'] + W = params['W'] + k = params['num_cars'] + + flight_time_factor = params['flight_time_factor'] + comp_time_factor = params['comp_time_factor'] + trans_time_factor = params['trans_time_factor'] + car_time_factor = params['car_time_factor'] + bs_time_factor = params['bs_time_factor'] + + flight_energy_factor = params['flight_energy_factor'] + comp_energy_factor = params['comp_energy_factor'] + trans_energy_factor = params['trans_energy_factor'] + battery_energy_capacity = params['battery_energy_capacity'] + + # 定义数字列表 + numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + # 生成所有的排列情况(取三次,每次都可以从10个数中选) + row_product = list(product(numbers, repeat=R)) + # 对每种情况从小到大排序,并剔除重复的情况 + row_cuts_set = set( + tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) + row_cuts_set = sorted(row_cuts_set) + + col_product = list(product(numbers, repeat=C)) + col_cuts_set = set( + tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) + col_cuts_set = sorted(col_cuts_set) + + best_T = float('inf') + best_solution = None + best_row_boundaries = None + best_col_boundaries = None + + all_tasks = [(row_cuts, col_cuts) + for row_cuts in row_cuts_set for col_cuts in col_cuts_set] + total_iterations = len(all_tasks) + + with ProcessPoolExecutor(max_workers=batch_size) as executor: + futures = set() + results = [] + + with tqdm(total=total_iterations) as pbar: + for task in all_tasks: + if len(futures) >= batch_size: # 如果并行任务数达到 batch_size,等待已有任务完成 + for future in as_completed(futures): + results.append(future.result()) + pbar.update(1) # 更新进度条 + futures.clear() # 清空已完成的任务 + + futures.add(executor.submit( + process_partition, *task, params)) # 提交新任务 + + # 处理剩余未完成的任务 + for future in as_completed(futures): + results.append(future.result()) + pbar.update(1) + + # 处理计算结果,找到最优解 + for result in results: + if result: + current_solution, current_time, row_boundaries, col_boundaries, rectangles = result + if current_time < best_T: + best_T = current_time + best_solution = current_solution + best_row_boundaries = row_boundaries + best_col_boundaries = col_boundaries + + # 解析最佳路径 + car_paths = [[x-1 for x in sublist] + for sublist in best_solution] + + # 输出最佳方案 + print("Best solution:", best_solution) + print("Time:", best_T) + print("Row boundaries:", best_row_boundaries) + print("Col boundaries:", best_col_boundaries) + print("Car Paths:", car_paths) + + output_data = { + 'row_boundaries': best_row_boundaries, + 'col_boundaries': best_col_boundaries, + 'car_paths': car_paths + } + with open(f'./solutions/trav_ga_{params_file}_parallel.json', 'w', encoding='utf-8') as file: + json.dump(output_data, file, ensure_ascii=False, indent=4) diff --git a/test_new_GA/utils.py b/test_new_GA/utils.py new file mode 100644 index 0000000..32dd491 --- /dev/null +++ b/test_new_GA/utils.py @@ -0,0 +1,58 @@ +import numpy as np + + +def if_valid_partition(row_boundaries, col_boundaries, params): + H = params['H'] + W = params['W'] + k = params['num_cars'] + + flight_time_factor = params['flight_time_factor'] + comp_time_factor = params['comp_time_factor'] + trans_time_factor = params['trans_time_factor'] + car_time_factor = params['car_time_factor'] + bs_time_factor = params['bs_time_factor'] + + flight_energy_factor = params['flight_energy_factor'] + comp_energy_factor = params['comp_energy_factor'] + trans_energy_factor = params['trans_energy_factor'] + battery_energy_capacity = params['battery_energy_capacity'] + + # 根据分割边界生成所有矩形任务 + rectangles = [] + for i in range(len(row_boundaries) - 1): + for j in range(len(col_boundaries) - 1): + r1 = row_boundaries[i] + r2 = row_boundaries[i + 1] + c1 = col_boundaries[j] + c2 = col_boundaries[j + 1] + d = (r2 - r1) * H * (c2 - c1) * W # 任务的照片数量(矩形面积) + + # 求解rho + rho_time_limit = (flight_time_factor - trans_time_factor) / \ + (comp_time_factor - trans_time_factor) + rho_energy_limit = (battery_energy_capacity - flight_energy_factor * d - trans_energy_factor * d) / \ + (comp_energy_factor * d - trans_energy_factor * d) + if rho_energy_limit < 0: + return [] + + rho = min(rho_time_limit, rho_energy_limit) + flight_time = flight_time_factor * d + comp_time = comp_time_factor * rho * d + trans_time = trans_time_factor * (1 - rho) * d + bs_time = bs_time_factor * (1 - rho) * d + + # 计算任务矩形中心,用于后续车辆移动时间计算 + center_r = (r1 + r2) / 2.0 * H + center_c = (c1 + c2) / 2.0 * W + + rectangles.append({ + # 'r1': r1, 'r2': r2, 'c1': c1, 'c2': c2, + 'd': d, + 'rho': rho, + 'flight_time': flight_time, + 'comp_time': comp_time, + 'trans_time': trans_time, + 'bs_time': bs_time, + 'center': (center_r, center_c) + }) + return rectangles diff --git a/visualization.py b/visualization.py index c62eb52..e183909 100644 --- a/visualization.py +++ b/visualization.py @@ -14,8 +14,8 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r # 设置英文标题和标签 # ax.set_title("Monte Carlo", fontsize=12) - # ax.set_title("Greedy", fontsize=12) - ax.set_title("Enumeration-Genetic Algorithm", fontsize=12) + ax.set_title("Greedy", fontsize=12) + # ax.set_title("Enumeration-Genetic Algorithm", fontsize=12) # ax.set_title("DQN fine-tuning", fontsize=12) ax.set_xlabel("Region Width", fontsize=10) @@ -199,8 +199,8 @@ if __name__ == "__main__": # --------------------------- # 需要修改的超参数 # --------------------------- - params_file = 'params_50_50_3' - solution_file = r'solutions\dqn_params_50_50_3_2.json' + params_file = 'params_100_100_6' + solution_file = r'solutions\greedy_params_100_100_6.json' with open(params_file + '.yml', 'r', encoding='utf-8') as file: params = yaml.safe_load(file)