HPCC2025/test_new_GA/main_parallel.py
2025-04-03 17:24:54 +08:00

133 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import math
import yaml
import numpy as np
from utils import if_valid_partition
from itertools import product
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from GA_MTSP import MTSP_GA
def process_partition(row_cuts, col_cuts, params, population_size=200, max_iterations=2000):
    """Build one grid partition, validate it, and solve its routing problem with a GA.

    ``0.0`` and ``1.0`` are added around the given interior cut positions to
    form the row/column boundaries on a normalized [0, 1] axis. The partition
    is validated by ``if_valid_partition``; invalid partitions are skipped.

    Args:
        row_cuts: Iterable of interior row-cut positions (presumably sorted
            values in (0, 1) — confirm against the caller).
        col_cuts: Iterable of interior column-cut positions (same convention).
        params: Problem parameters loaded from YAML. Must contain ``'H'`` and
            ``'W'``; also passed through to ``if_valid_partition`` and ``MTSP_GA``.
        population_size: GA population size (default 200, the previously
            hard-coded value).
        max_iterations: GA iteration budget (default 2000, the previously
            hard-coded value).

    Returns:
        ``None`` when ``if_valid_partition`` returns no rectangles. Otherwise
        the tuple ``(solution, time, row_boundaries, col_boundaries, rectangles)``,
        where ``time`` and ``solution`` are what ``MTSP_GA.solve()`` returns.
    """
    row_boundaries = [0.0, *row_cuts, 1.0]
    col_boundaries = [0.0, *col_cuts, 1.0]
    rectangles = if_valid_partition(row_boundaries, col_boundaries, params)
    if not rectangles:
        return None  # Invalid partition: filter it out.

    # City 0 is the point (H/2, W/2) (presumably the depot at the centre of the
    # area — confirm); city i >= 1 is the centre of rectangles[i - 1].
    cities = [(params['H'] / 2.0, params['W'] / 2.0)]
    cities.extend(rec['center'] for rec in rectangles)

    solver = MTSP_GA(
        cities=cities,
        params=params,
        rectangles=rectangles,
        population_size=population_size,
        max_iterations=max_iterations,
    )
    current_time, current_solution = solver.solve()
    return (current_solution, current_time, row_boundaries, col_boundaries, rectangles)
def candidate_cuts(numbers, max_cuts):
    """Return every distinct sorted tuple of at most ``max_cuts`` positive values.

    Equivalent to taking ``itertools.product(numbers, repeat=max_cuts)``,
    dropping non-positive entries and duplicates within each product, sorting
    each tuple, de-duplicating and sorting the whole collection. It enumerates
    combinations directly instead of all ``len(numbers) ** max_cuts`` products.
    The empty tuple (no cut at all) is included.

    >>> candidate_cuts([0.0, 0.1, 0.2], 2)
    [(), (0.1,), (0.1, 0.2), (0.2,)]
    """
    from itertools import combinations

    positive = sorted({value for value in numbers if value > 0})
    return sorted(
        combo
        for size in range(max_cuts + 1)
        for combo in combinations(positive, size)
    )


def solve_all_partitions(tasks, params, max_workers):
    """Run ``process_partition`` on every ``(row_cuts, col_cuts)`` task in parallel.

    At most ``max_workers`` tasks are in flight at once. A new task is
    submitted as soon as any running one finishes, so a single slow GA run
    does not stall an entire batch.

    Returns:
        The ``process_partition`` result tuple with the smallest time, or
        ``None`` if no task produced a result with a time below infinity
        (e.g. every partition was rejected). Ties go to the task that appears
        first in ``tasks``, independent of completion order.
    """
    from concurrent.futures import FIRST_COMPLETED, wait

    tasks = list(tasks)
    best_key = (float('inf'), -1)  # (time, task index) of the best result so far
    best_result = None
    task_iter = iter(enumerate(tasks))

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(total=len(tasks)) as pbar:
            pending = {}  # future -> index of its task in ``tasks``
            while True:
                # Top up the window to max_workers in-flight tasks.
                while len(pending) < max_workers:
                    next_task = next(task_iter, None)
                    if next_task is None:
                        break
                    index, (row_cuts, col_cuts) = next_task
                    future = executor.submit(process_partition, row_cuts, col_cuts, params)
                    pending[future] = index
                if not pending:
                    break  # Every task has been submitted and collected.

                # Block only until at least one task finishes, then refill.
                done, _ = wait(pending, return_when=FIRST_COMPLETED)
                for future in done:
                    index = pending.pop(future)
                    result = future.result()  # Re-raises any worker exception.
                    pbar.update(1)
                    if result is not None and (result[1], index) < best_key:
                        best_key = (result[1], index)
                        best_result = result
    return best_result


if __name__ == "__main__":  # Important: required on Windows, where workers re-import this module
    import os

    # NOTE(review): these seeds only seed the RNGs of this parent process. The
    # GA runs inside ProcessPoolExecutor workers, which are not seeded per task
    # here, so results are presumably not reproducible run-to-run — confirm.
    np.random.seed(42)
    random.seed(42)

    # ---------------------------
    # Hyperparameters to edit
    # ---------------------------
    R = 3  # maximum number of interior row cuts
    C = 1  # maximum number of interior column cuts
    params_file = 'params_50_50_3'
    # Maximum number of tasks running in parallel (also the worker count).
    # ProcessPoolExecutor on Windows rejects max_workers > 61.
    batch_size = 60

    with open(params_file + '.yml', 'r', encoding='utf-8') as file:
        params = yaml.safe_load(file)

    # Candidate cut positions on the normalized [0, 1] axis.
    numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    row_cuts_set = candidate_cuts(numbers, R)
    col_cuts_set = candidate_cuts(numbers, C)
    all_tasks = [(row_cuts, col_cuts)
                 for row_cuts in row_cuts_set for col_cuts in col_cuts_set]

    best = solve_all_partitions(all_tasks, params, batch_size)
    if best is None:
        raise SystemExit("No valid partition found; nothing to write.")
    best_solution, best_T, best_row_boundaries, best_col_boundaries, _ = best

    # Decode the best routes. process_partition puts the depot at city 0 and
    # rectangles[i - 1] at city i, so subtracting 1 maps back to rectangle
    # indices (assumes best_solution holds per-car lists of city indices —
    # confirm against MTSP_GA).
    car_paths = [[x - 1 for x in sublist] for sublist in best_solution]

    # Report the best plan.
    print("Best solution:", best_solution)
    print("Time:", best_T)
    print("Row boundaries:", best_row_boundaries)
    print("Col boundaries:", best_col_boundaries)
    print("Car Paths:", car_paths)

    output_data = {
        'row_boundaries': best_row_boundaries,
        'col_boundaries': best_col_boundaries,
        'car_paths': car_paths
    }
    os.makedirs('./solutions', exist_ok=True)  # open() below fails if the directory is missing
    with open(f'./solutions/trav_ga_{params_file}_parallel.json', 'w', encoding='utf-8') as file:
        json.dump(output_data, file, ensure_ascii=False, indent=4)