HPCC2025/GA/main.py
weixin_46229132 981681c1bd 修改dqn bug
2025-04-01 20:45:13 +08:00

115 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import math
import os
import random
from itertools import product

import numpy as np
import yaml
from tqdm import tqdm

from utils import if_valid_partition, GA_solver
# Seed both RNG sources so repeated runs are reproducible (presumably the GA
# inside GA_solver draws from one of them -- confirm in utils).
random.seed(42)
np.random.seed(42)

# Best-so-far state of the exhaustive partition search.
# best_T starts at +inf so the first feasible solution always replaces it.
best_T = float('inf')
best_solution = best_row_boundaries = best_col_boundaries = None
# ---------------------------
# Hyperparameters to edit per experiment
# ---------------------------
# Maximum number of interior cuts tried along the row / column axis
# (the cut enumeration below draws this many candidate positions).
R = 3
C = 3
# Base name of the YAML parameter file; it is also embedded in the name of
# the JSON solution file written at the end of the script.
params_file = 'params_50_50_3'

with open(params_file + '.yml', 'r', encoding='utf-8') as param_stream:
    params = yaml.safe_load(param_stream)

# Values unpacked from the parameter file; names mirror the YAML keys.
# NOTE(review): none of these names are read elsewhere in the visible code --
# the helpers receive the whole `params` dict instead.
H, W, k = params['H'], params['W'], params['num_cars']
# Time-model coefficients.
(flight_time_factor, comp_time_factor, trans_time_factor,
 car_time_factor, bs_time_factor) = (
    params['flight_time_factor'], params['comp_time_factor'],
    params['trans_time_factor'], params['car_time_factor'],
    params['bs_time_factor'])
# Energy-model coefficients and battery capacity.
(flight_energy_factor, comp_energy_factor, trans_energy_factor,
 battery_energy_capacity) = (
    params['flight_energy_factor'], params['comp_energy_factor'],
    params['trans_energy_factor'], params['battery_energy_capacity'])
# Candidate cut positions along one axis, as fractions of the side length.
# 0.0 acts as a "no cut" placeholder and is discarded during enumeration.
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]


def _distinct_cut_sets(candidates, max_cuts):
    """Return every distinct set of interior cut positions, sorted.

    Draws ``max_cuts`` values from ``candidates`` with replacement (every
    ordered draw produced by ``itertools.product``), keeps only the strictly
    positive values of each draw, removes duplicates within the draw and
    sorts them, then removes duplicate sets across draws.

    Args:
        candidates: Candidate cut positions; non-positive values mean
            "no cut" and are dropped.
        max_cuts: Number of draws, i.e. the maximum number of cuts per set.

    Returns:
        list[tuple]: Distinct ascending tuples of cut positions, in ascending
        lexicographic order. ``()`` (no cut at all) is included when
        ``candidates`` contains a non-positive value (or ``max_cuts`` is 0).
    """
    return sorted({
        tuple(sorted({value for value in draw if value > 0}))
        for draw in product(candidates, repeat=max_cuts)
    })


# Same enumeration for both axes; only the maximum cut count differs.
row_cuts_set = _distinct_cut_sets(numbers, R)
col_cuts_set = _distinct_cut_sets(numbers, C)
total_iterations = len(row_cuts_set) * len(col_cuts_set)
def _split_into_car_paths(solution, separators):
    """Decode a GA solution into one path per vehicle.

    Positions of ``solution`` whose value is in ``separators`` act as
    delimiters. For each pair of consecutive delimiters, the values strictly
    between them form one path, each shifted down by one (``value - 1``;
    presumably converting 1-based rectangle ids to 0-based indices -- confirm
    against GA_solver). Empty segments are dropped, and values before the
    first or after the last delimiter are not assigned to any path.

    Args:
        solution: Indexable sequence returned by ``GA_solver`` as its solution.
        separators: Container returned by ``GA_solver`` as ``to_process_idx``;
            only ``in`` membership tests are performed on it.

    Returns:
        list[list]: One list of shifted values per non-empty segment.
    """
    delimiter_positions = [
        pos for pos, value in enumerate(solution) if value in separators
    ]
    paths = []
    for start, end in zip(delimiter_positions, delimiter_positions[1:]):
        path = [
            solution[pos] - 1
            for pos in range(start + 1, end)
            if solution[pos] not in separators
        ]
        if path:
            paths.append(path)
    return paths


# Pre-initialise so the report step never hits an undefined name when no cut
# combination yields a valid partition.
car_paths = []
with tqdm(total=total_iterations, desc="Processing") as pbar:
    # Exhaustively try every (row cuts, column cuts) combination; product()
    # keeps the original order (rows outer, columns inner).
    for row_cuts, col_cuts in product(row_cuts_set, col_cuts_set):
        row_boundaries = [0.0] + list(row_cuts) + [1.0]
        col_boundaries = [0.0] + list(col_cuts) + [1.0]
        # Per the original author's note, the rectangles returned here are
        # expressed in real distances, not fractions.
        rectangles = if_valid_partition(row_boundaries, col_boundaries, params)
        if rectangles:
            # Solve this grid partition with the GA and keep the best result
            # over all partitions.
            current_solution, current_time, to_process_idx = GA_solver(
                rectangles, params)
            if current_time < best_T:
                best_T = current_time
                best_solution = current_solution
                best_row_boundaries = row_boundaries
                best_col_boundaries = col_boundaries
                # A dedicated helper avoids the old inner loop variable `k`,
                # which overwrote the global `k` (num_cars).
                car_paths = _split_into_car_paths(best_solution, to_process_idx)
        pbar.update(1)
# ---------------------------
# Report and persist the best plan
# ---------------------------
if best_solution is None:
    # No cut combination produced a valid partition: best_T is still +inf and
    # no car paths were decoded, so there is nothing meaningful to print or
    # save. Exit with a clear message instead of a NameError / null JSON.
    raise SystemExit("No valid partition found; no solution was written.")

print("Best solution:", best_solution)
print("Time:", best_T)
print("Row boundaries:", best_row_boundaries)
print("Col boundaries:", best_col_boundaries)
print("Car Paths:", car_paths)
output_data = {
    'row_boundaries': best_row_boundaries,
    'col_boundaries': best_col_boundaries,
    'car_paths': car_paths
}
# The output directory may not exist yet; create it rather than crash after
# the (potentially long) search has already finished.
os.makedirs('./solutions', exist_ok=True)
with open(f'./solutions/trav_ga_{params_file}.json', 'w', encoding='utf-8') as file:
    json.dump(output_data, file, ensure_ascii=False, indent=4)