HPCC2025/GA/main_parallel.py

149 lines
5.4 KiB
Python
Raw Permalink Normal View History

2025-03-22 17:16:58 +08:00
import random
import math
import yaml
import numpy as np
from utils import if_valid_partition, GA_solver
2025-04-03 17:24:54 +08:00
from itertools import product, combinations
2025-03-22 17:16:58 +08:00
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
2025-03-22 21:43:11 +08:00
def process_partition(row_cuts, col_cuts, params):
    """Evaluate one grid partition of the unit square with the GA solver.

    The interior cut positions are wrapped with the 0.0 / 1.0 edges to form
    the boundary lists, which are checked by ``if_valid_partition``; if the
    check yields anything truthy it is handed to ``GA_solver``.

    Args:
        row_cuts: Iterable of interior row cut positions (presumably sorted
            and strictly inside (0, 1) — TODO confirm for other callers).
        col_cuts: Iterable of interior column cut positions, same form.
        params: Config dict loaded from the YAML params file; passed through
            unchanged to ``if_valid_partition`` and ``GA_solver``.

    Returns:
        ``None`` when ``if_valid_partition`` returns a falsy value (the
        partition is rejected); otherwise the tuple
        ``(solution, time, row_boundaries, col_boundaries, to_process_idx,
        rectangles)``, where solution, time and to_process_idx are exactly
        what ``GA_solver`` returned.
    """
    rows = [0.0, *row_cuts, 1.0]
    cols = [0.0, *col_cuts, 1.0]

    rects = if_valid_partition(rows, cols, params)
    if rects:
        solution, elapsed, start_markers = GA_solver(rects, params)
        return solution, elapsed, rows, cols, start_markers, rects

    # Falsy result from the validity check: discard this partition.
    return None
if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
2025-03-22 21:43:11 +08:00
np.random.seed(42)
random.seed(42)
2025-03-24 17:09:51 +08:00
# ---------------------------
# 需要修改的超参数
# ---------------------------
2025-04-05 10:36:03 +08:00
R = 7
C = 7
2025-04-03 17:24:54 +08:00
params_file = 'params_100_100_6'
2025-03-31 14:23:29 +08:00
batch_size = 60 # 控制一次最多并行多少个任务
2025-03-24 17:09:51 +08:00
2025-03-22 21:43:11 +08:00
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file)
H = params['H']
W = params['W']
k = params['num_cars']
flight_time_factor = params['flight_time_factor']
comp_time_factor = params['comp_time_factor']
trans_time_factor = params['trans_time_factor']
car_time_factor = params['car_time_factor']
bs_time_factor = params['bs_time_factor']
flight_energy_factor = params['flight_energy_factor']
comp_energy_factor = params['comp_energy_factor']
trans_energy_factor = params['trans_energy_factor']
battery_energy_capacity = params['battery_energy_capacity']
# 定义数字列表
2025-04-03 17:24:54 +08:00
numbers = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
2025-03-22 21:43:11 +08:00
# 生成所有的排列情况取三次每次都可以从10个数中选
2025-04-03 17:24:54 +08:00
# row_product = list(product(numbers, repeat=R))
# # 对每种情况从小到大排序,并剔除重复的情况
# row_cuts_set = set(
# tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product)
# row_cuts_set = sorted(row_cuts_set)
# col_product = list(product(numbers, repeat=C))
# col_cuts_set = set(
# tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product)
# col_cuts_set = sorted(col_cuts_set)
row_cuts_set = []
for i in range(R):
row_cuts_set += list(combinations(numbers, i+1))
col_cuts_set = []
for i in range(C):
col_cuts_set += list(combinations(numbers, i+1))
2025-03-22 21:43:11 +08:00
2025-03-22 17:16:58 +08:00
best_T = float('inf')
best_solution = None
best_row_boundaries = None
best_col_boundaries = None
2025-03-22 21:43:11 +08:00
all_tasks = [(row_cuts, col_cuts)
for row_cuts in row_cuts_set for col_cuts in col_cuts_set]
2025-03-22 17:16:58 +08:00
total_iterations = len(all_tasks)
with ProcessPoolExecutor(max_workers=batch_size) as executor:
futures = set()
results = []
with tqdm(total=total_iterations) as pbar:
for task in all_tasks:
if len(futures) >= batch_size: # 如果并行任务数达到 batch_size等待已有任务完成
for future in as_completed(futures):
results.append(future.result())
pbar.update(1) # 更新进度条
futures.clear() # 清空已完成的任务
2025-03-24 15:42:42 +08:00
futures.add(executor.submit(
process_partition, *task, params)) # 提交新任务
2025-03-22 17:16:58 +08:00
# 处理剩余未完成的任务
for future in as_completed(futures):
results.append(future.result())
pbar.update(1)
# 处理计算结果,找到最优解
for result in results:
if result:
current_solution, current_time, row_boundaries, col_boundaries, to_process_idx, rectrangles = result
if current_time < best_T:
best_T = current_time
best_solution = current_solution
best_row_boundaries = row_boundaries
best_col_boundaries = col_boundaries
# 解析最佳路径
found_start_points_indices = []
for i in range(len(best_solution)):
if best_solution[i] in to_process_idx:
found_start_points_indices.append(i)
car_paths = []
for j in range(len(found_start_points_indices) - 1):
from_index = found_start_points_indices[j]
end_index = found_start_points_indices[j + 1]
car_path = []
for k in range(from_index, end_index + 1):
rectrangle_idx = best_solution[k]
if rectrangle_idx not in to_process_idx:
2025-03-24 15:42:42 +08:00
car_path.append(rectrangle_idx - 1)
2025-03-22 21:43:11 +08:00
if car_path:
car_paths.append(car_path)
2025-03-22 17:16:58 +08:00
# 输出最佳方案
print("Best solution:", best_solution)
print("Time:", best_T)
print("Row boundaries:", best_row_boundaries)
print("Col boundaries:", best_col_boundaries)
2025-03-24 17:09:51 +08:00
print("Car Paths:", car_paths)
2025-03-22 17:16:58 +08:00
output_data = {
2025-03-22 21:43:11 +08:00
'row_boundaries': best_row_boundaries,
'col_boundaries': best_col_boundaries,
2025-03-22 17:16:58 +08:00
'car_paths': car_paths
}
2025-03-22 21:43:11 +08:00
with open(f'./solutions/trav_ga_{params_file}_parallel.json', 'w', encoding='utf-8') as file:
2025-03-22 17:16:58 +08:00
json.dump(output_data, file, ensure_ascii=False, indent=4)