修改划分列举方法

This commit is contained in:
weixin_46229132 2025-04-03 17:24:54 +08:00
parent adaf8cc50e
commit 23aafc2998
13 changed files with 762 additions and 37 deletions

View File

@ -16,7 +16,7 @@ class GA(object):
self.location = data self.location = data
self.to_process_idx = to_process_idx self.to_process_idx = to_process_idx
self.rectangles = rectangles self.rectangles = rectangles
self.epochs = 300 self.epochs = 1000
self.ga_choose_ratio = 0.2 self.ga_choose_ratio = 0.2
self.mutate_ratio = 0.05 self.mutate_ratio = 0.05
# fruits中存每一个个体是下标的list # fruits中存每一个个体是下标的list
@ -314,7 +314,7 @@ class GA(object):
early_stop_cnt = 0 early_stop_cnt = 0
else: else:
early_stop_cnt += 1 early_stop_cnt += 1
if early_stop_cnt == 50: # 若连续50次没有性能提升则早停 if early_stop_cnt == 100: # 若连续50次没有性能提升则早停
break break
self.best_record.append(1.0 / best_score) self.best_record.append(1.0 / best_score)
best_length = 1.0 / best_score best_length = 1.0 / best_score

View File

@ -3,7 +3,7 @@ import math
import yaml import yaml
import numpy as np import numpy as np
from utils import if_valid_partition, GA_solver from utils import if_valid_partition, GA_solver
from itertools import product from itertools import product, combinations
import json import json
from tqdm import tqdm from tqdm import tqdm
@ -18,9 +18,9 @@ best_col_boundaries = None
# --------------------------- # ---------------------------
# 需要修改的超参数 # 需要修改的超参数
# --------------------------- # ---------------------------
R = 3 R = 10
C = 3 C = 10
params_file = 'params_50_50_3' params_file = 'params_100_100_5'
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:
@ -41,20 +41,29 @@ comp_energy_factor = params['comp_energy_factor']
trans_energy_factor = params['trans_energy_factor'] trans_energy_factor = params['trans_energy_factor']
battery_energy_capacity = params['battery_energy_capacity'] battery_energy_capacity = params['battery_energy_capacity']
# 定义数字列表 # # 定义数字列表
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] # numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# 生成所有的排列情况取三次每次都可以从10个数中选 # # 生成所有的排列情况取三次每次都可以从10个数中选
row_product = list(product(numbers, repeat=R)) # row_product = list(product(numbers, repeat=R))
# 对每种情况从小到大排序,并剔除重复的情况 # # 对每种情况从小到大排序,并剔除重复的情况
row_cuts_set = set( # row_cuts_set = set(
tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) # tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product)
row_cuts_set = sorted(row_cuts_set) # row_cuts_set = sorted(row_cuts_set)
col_product = list(product(numbers, repeat=C)) # col_product = list(product(numbers, repeat=C))
col_cuts_set = set( # col_cuts_set = set(
tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) # tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product)
col_cuts_set = sorted(col_cuts_set) # col_cuts_set = sorted(col_cuts_set)
numbers = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
row_cuts_set = []
for i in range(R):
row_cuts_set += list(combinations(numbers, i+1))
col_cuts_set = []
for i in range(C):
col_cuts_set += list(combinations(numbers, i+1))
total_iterations = len(row_cuts_set) * len(col_cuts_set) total_iterations = len(row_cuts_set) * len(col_cuts_set)
with tqdm(total=total_iterations, desc="Processing") as pbar: with tqdm(total=total_iterations, desc="Processing") as pbar:

View File

@ -3,7 +3,7 @@ import math
import yaml import yaml
import numpy as np import numpy as np
from utils import if_valid_partition, GA_solver from utils import if_valid_partition, GA_solver
from itertools import product from itertools import product, combinations
import json import json
from tqdm import tqdm from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
@ -30,9 +30,9 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
# --------------------------- # ---------------------------
# 需要修改的超参数 # 需要修改的超参数
# --------------------------- # ---------------------------
R = 3 R = 10
C = 1 C = 10
params_file = 'params_50_50_3' params_file = 'params_100_100_6'
batch_size = 60 # 控制一次最多并行多少个任务 batch_size = 60 # 控制一次最多并行多少个任务
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:
@ -54,19 +54,27 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
battery_energy_capacity = params['battery_energy_capacity'] battery_energy_capacity = params['battery_energy_capacity']
# 定义数字列表 # 定义数字列表
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] numbers = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# 生成所有的排列情况取三次每次都可以从10个数中选 # 生成所有的排列情况取三次每次都可以从10个数中选
row_product = list(product(numbers, repeat=R)) # row_product = list(product(numbers, repeat=R))
# 对每种情况从小到大排序,并剔除重复的情况 # # 对每种情况从小到大排序,并剔除重复的情况
row_cuts_set = set( # row_cuts_set = set(
tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product) # tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product)
row_cuts_set = sorted(row_cuts_set) # row_cuts_set = sorted(row_cuts_set)
col_product = list(product(numbers, repeat=C)) # col_product = list(product(numbers, repeat=C))
col_cuts_set = set( # col_cuts_set = set(
tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product) # tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product)
col_cuts_set = sorted(col_cuts_set) # col_cuts_set = sorted(col_cuts_set)
row_cuts_set = []
for i in range(R):
row_cuts_set += list(combinations(numbers, i+1))
col_cuts_set = []
for i in range(C):
col_cuts_set += list(combinations(numbers, i+1))
best_T = float('inf') best_T = float('inf')
best_solution = None best_solution = None

View File

@ -81,7 +81,7 @@ def main():
# --------------------------- # ---------------------------
# 需要修改的超参数 # 需要修改的超参数
# --------------------------- # ---------------------------
params_file = 'params_50_50_3' params_file = 'params_100_100_6'
# 读取参数 # 读取参数
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:

16
params_100_100_6.yml Normal file
View File

@ -0,0 +1,16 @@
H : 100 # 区域高度网格点之间的距离为25m单位距离
W : 100 # 区域宽度
num_cars : 6 # 系统数量(车-巢-机系统个数)
# 时间系数(单位:秒,每个网格一张照片)
flight_time_factor : 3 # 每张照片对应的飞行时间无人机飞行速度为9.5m/s拍摄照片的时间间隔为3s
comp_time_factor : 5 # 无人机上每张照片计算时间5s
trans_time_factor : 0.3 # 每张照片传输时间0.3s
car_time_factor : 100 # TODO 汽车每单位距离的移动时间2s加了一个放大因子50
bs_time_factor : 5 # 机巢上每张照片计算时间
# 其他参数
flight_energy_factor : 0.05 # 单位:分钟/张
comp_energy_factor : 0.05 # TODO 计算能耗需要重新估计
trans_energy_factor : 0.0025
battery_energy_capacity : 20 # 无人机只进行飞行续航为30分钟

View File

@ -0,0 +1,73 @@
{
"row_boundaries": [
0.0,
0.2,
0.4,
0.6000000000000001,
0.8,
1.0
],
"col_boundaries": [
0,
0.12533333333333332,
0.25066666666666665,
0.376,
0.5013333333333333,
0.6266666666666666,
0.7519999999999999,
0.8773333333333332,
1
],
"car_paths": [
[
0,
1,
2,
3,
4,
5,
6,
7
],
[
8,
9,
10,
11,
12,
13,
14,
15
],
[
16,
17,
18,
19,
20,
21,
22,
23
],
[
24,
25,
26,
27,
28,
29,
30,
31
],
[
32,
33,
34,
35,
36,
37,
38,
39
]
]
}

View File

@ -0,0 +1,77 @@
{
"row_boundaries": [
0.0,
0.16666666666666666,
0.3333333333333333,
0.5,
0.6666666666666666,
0.8333333333333333,
1.0
],
"col_boundaries": [
0,
0.1504,
0.3008,
0.45120000000000005,
0.6016,
0.752,
0.9024,
1
],
"car_paths": [
[
0,
1,
2,
3,
4,
5,
6
],
[
7,
8,
9,
10,
11,
12,
13
],
[
14,
15,
16,
17,
18,
19,
20
],
[
21,
22,
23,
24,
25,
26,
27
],
[
28,
29,
30,
31,
32,
33,
34
],
[
35,
36,
37,
38,
39,
40,
41
]
]
}

View File

@ -0,0 +1,74 @@
{
"row_boundaries": [
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0
],
"col_boundaries": [
0.0,
0.2,
0.4,
0.7,
1.0
],
"car_paths": [
[
17,
5,
4,
0,
1,
2,
6,
10,
14
],
[
18,
13,
9,
8,
12,
16,
20,
32,
21
],
[
22,
26,
30,
34,
39,
35,
31,
27,
19
],
[
25,
24,
28,
36,
33,
37,
38,
29
],
[
15,
23,
7,
11,
3
]
]
}

278
test_new_GA/GA_MTSP.py Normal file
View File

@ -0,0 +1,278 @@
import numpy as np
import random
import math
import matplotlib.pyplot as plt
import time
# 设置随机种子
np.random.seed(42)
random.seed(42)
class MTSP_GA:
def __init__(self, cities, params, rectangles, population_size=200, max_iterations=1500):
"""
初始化遗传算法求解器
Args:
cities: 城市坐标数组第一个城市为起始点
vehicle_num: 车辆数量
population_size: 种群大小
max_iterations: 最大迭代次数
"""
H = params['H']
W = params['W']
k = params['num_cars']
flight_time_factor = params['flight_time_factor']
comp_time_factor = params['comp_time_factor']
trans_time_factor = params['trans_time_factor']
car_time_factor = params['car_time_factor']
bs_time_factor = params['bs_time_factor']
flight_energy_factor = params['flight_energy_factor']
comp_energy_factor = params['comp_energy_factor']
trans_energy_factor = params['trans_energy_factor']
battery_energy_capacity = params['battery_energy_capacity']
self.cities = np.array(cities)
self.city_count = len(cities)
self.vehicle_num = k
self.origin = 0 # 起始点
self.rectangles = rectangles
self.car_time_factor = car_time_factor
self.H = H
self.W = W
# GA参数
self.population_size = population_size
self.max_iterations = max_iterations
self.retain_rate = 0.3 # 强者存活率
self.random_rate = 0.5 # 弱者存活概率
self.mutation_rate = 0.5 # 变异率
# 计算距离矩阵
self.distance_matrix = self._compute_distance_matrix()
# 记录收敛过程
self.distance_history = []
self.best_path_history = []
def _compute_distance_matrix(self):
"""计算城市间距离矩阵"""
distance = np.zeros((self.city_count, self.city_count))
for i in range(self.city_count):
for j in range(self.city_count):
distance[i][j] = math.sqrt(
(self.cities[i][0] - self.cities[j][0]) ** 2 +
(self.cities[i][1] - self.cities[j][1]) ** 2
)
return distance
def _create_individual(self):
"""生成初始个体"""
index = [i for i in range(self.city_count)]
index.remove(self.origin)
a = int(np.floor(len(index)/self.vehicle_num))
X = []
for i in range(self.vehicle_num):
if i < self.vehicle_num-1:
x = index[a*i:a*(i+1)]
else:
x = index[a*i:]
X.append(x)
return X
def _get_total_distance(self, X):
"""计算路径总距离"""
# 根据car_paths计算时间
for car_idx in range(self.vehicle_num):
car_path = X[car_idx]
flight_time = sum(self.rectangles[point - 1]['flight_time']
for point in car_path)
bs_time = sum(self.rectangles[point - 1]
['bs_time'] for point in car_path)
# 计算车的移动时间,首先在轨迹的首尾添加上大区域中心
car_time = 0
car_time += self.distance_matrix[self.origin][car_path[0]] * self.car_time_factor
car_time += self.distance_matrix[self.origin][car_path[-1]] * self.car_time_factor
for i in range(len(car_path) - 1):
first_point = car_path[i]
second_point = car_path[i + 1]
car_time += self.distance_matrix[first_point][second_point] * self.car_time_factor
system_time = max(flight_time + car_time, bs_time)
return system_time
def _selection(self, population):
"""选择操作"""
graded = [[self._get_total_distance(x), x] for x in population]
graded = [x[1] for x in sorted(graded)]
retain_length = int(len(graded) * self.retain_rate)
parents = graded[:retain_length]
for chromosome in graded[retain_length:]:
if random.random() < self.random_rate:
parents.append(chromosome)
return parents
def _crossover(self, parents):
"""交叉操作"""
target_count = self.population_size - len(parents)
children = []
while len(children) < target_count:
male_index = random.randint(0, len(parents) - 1)
female_index = random.randint(0, len(parents) - 1)
if male_index != female_index:
male = parents[male_index]
female = parents[female_index]
gene1 = []
gene2 = []
for i in range(len(male)):
gene1 += male[i]
gene2 += female[i]
left = random.randint(0, len(gene1)//2)
right = random.randint(left + 1, len(gene1))
cut = gene1[left:right]
copy = gene2.copy()
for j in cut:
copy.remove(j)
child = copy + cut
a = int(np.floor(len(child)/self.vehicle_num))
child_c = []
for i in range(self.vehicle_num):
if i < self.vehicle_num - 1:
x = child[a * i:a * (i + 1)]
else:
x = child[a * i:]
child_c.append(x)
children.append(child_c)
return children
def _mutation(self, children):
"""变异操作"""
for i in range(len(children)):
if random.random() < self.mutation_rate:
child = children[i]
for j in range(int(np.floor(len(child)/2))):
a = 2*j
u = random.randint(1, len(child[a]) - 1)
w = random.randint(1, len(child[a+1]) - 1)
child_1 = child[a][:u].copy()
child_2 = child[a][u:].copy()
child_3 = child[a+1][:w].copy()
child_4 = child[a+1][w:].copy()
child_a = child_1+child_3
child_b = child_2+child_4
child[a] = child_a
child[a+1] = child_b
children[i] = child.copy()
return children
def _get_best_solution(self, population):
"""获取最优解"""
graded = [[self._get_total_distance(x), x] for x in population]
graded = sorted(graded, key=lambda x: x[0])
return graded[0][0], graded[0][1]
def solve(self):
"""
求解MTSP加入早停机制
当连续50轮没有更好的解时停止迭代
"""
# 初始化种群
population = [self._create_individual()
for _ in range(self.population_size)]
# 初始化早停相关变量
best_distance = float('inf')
early_stop_counter = 0
early_stop_threshold = 100
# 迭代优化
for i in range(self.max_iterations):
parents = self._selection(population)
children = self._crossover(parents)
children = self._mutation(children)
population = parents + children
# 记录当前最优解
current_distance, current_path = self._get_best_solution(
population)
self.distance_history.append(current_distance)
self.best_path_history.append(current_path)
# 早停判断
if current_distance < best_distance:
best_distance = current_distance
best_path = current_path
# early_stop_counter = 0 # 重置计数器
# else:
# early_stop_counter += 1
# # 如果连续50轮没有更好的解则停止迭代
# if early_stop_counter >= early_stop_threshold:
# # print(
# # f"Early stopping at iteration {i} due to no improvement in {early_stop_threshold} iterations")
# break
# 返回最优解
return best_distance, best_path
def plot_convergence(self):
"""绘制收敛曲线"""
plt.plot(range(len(self.distance_history)), self.distance_history)
plt.xlabel('Iteration')
plt.ylabel('Total Distance')
plt.title('Convergence Curve')
plt.show()
def main():
# 城市坐标
cities = np.array([
(456, 320), # 起点(基地)
(228, 0),
(912, 0),
(0, 80),
(114, 80),
(570, 160),
(798, 160),
(342, 240),
(684, 240),
(570, 400),
(912, 400),
(114, 480),
(228, 480),
(342, 560),
(684, 560),
(0, 640),
(798, 640)
])
# 创建求解器实例
solver = MTSP_GA(
cities=cities,
vehicle_num=4,
population_size=200,
max_iterations=1500
)
# 求解
start_time = time.time()
best_distance, best_path = solver.solve()
end_time = time.time()
# 输出结果
print(f"最优总距离: {best_distance:.2f}")
print("最优路径方案:")
for i, path in enumerate(best_path):
print(f"车辆{i+1}的路径: {path}")
print(f"求解时间: {end_time - start_time:.2f}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,132 @@
import random
import math
import yaml
import numpy as np
from utils import if_valid_partition
from itertools import product
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from GA_MTSP import MTSP_GA
def process_partition(row_cuts, col_cuts, params):
row_boundaries = [0.0] + list(row_cuts) + [1.0]
col_boundaries = [0.0] + list(col_cuts) + [1.0]
rectangles = if_valid_partition(row_boundaries, col_boundaries, params)
if not rectangles:
return None # 过滤无效划分
cities = [(params['H'] / 2.0, params['W'] / 2.0)]
for rec in rectangles:
cities.append(rec['center'])
sovler = MTSP_GA(
cities=cities, params=params, rectangles=rectangles, population_size=200, max_iterations=2000)
current_time, current_solution = sovler.solve()
return (current_solution, current_time, row_boundaries, col_boundaries, rectangles)
if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
np.random.seed(42)
random.seed(42)
# ---------------------------
# 需要修改的超参数
# ---------------------------
R = 3
C = 1
params_file = 'params_50_50_3'
batch_size = 60 # 控制一次最多并行多少个任务
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file)
H = params['H']
W = params['W']
k = params['num_cars']
flight_time_factor = params['flight_time_factor']
comp_time_factor = params['comp_time_factor']
trans_time_factor = params['trans_time_factor']
car_time_factor = params['car_time_factor']
bs_time_factor = params['bs_time_factor']
flight_energy_factor = params['flight_energy_factor']
comp_energy_factor = params['comp_energy_factor']
trans_energy_factor = params['trans_energy_factor']
battery_energy_capacity = params['battery_energy_capacity']
# 定义数字列表
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# 生成所有的排列情况取三次每次都可以从10个数中选
row_product = list(product(numbers, repeat=R))
# 对每种情况从小到大排序,并剔除重复的情况
row_cuts_set = set(
tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product)
row_cuts_set = sorted(row_cuts_set)
col_product = list(product(numbers, repeat=C))
col_cuts_set = set(
tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product)
col_cuts_set = sorted(col_cuts_set)
best_T = float('inf')
best_solution = None
best_row_boundaries = None
best_col_boundaries = None
all_tasks = [(row_cuts, col_cuts)
for row_cuts in row_cuts_set for col_cuts in col_cuts_set]
total_iterations = len(all_tasks)
with ProcessPoolExecutor(max_workers=batch_size) as executor:
futures = set()
results = []
with tqdm(total=total_iterations) as pbar:
for task in all_tasks:
if len(futures) >= batch_size: # 如果并行任务数达到 batch_size等待已有任务完成
for future in as_completed(futures):
results.append(future.result())
pbar.update(1) # 更新进度条
futures.clear() # 清空已完成的任务
futures.add(executor.submit(
process_partition, *task, params)) # 提交新任务
# 处理剩余未完成的任务
for future in as_completed(futures):
results.append(future.result())
pbar.update(1)
# 处理计算结果,找到最优解
for result in results:
if result:
current_solution, current_time, row_boundaries, col_boundaries, rectangles = result
if current_time < best_T:
best_T = current_time
best_solution = current_solution
best_row_boundaries = row_boundaries
best_col_boundaries = col_boundaries
# 解析最佳路径
car_paths = [[x-1 for x in sublist]
for sublist in best_solution]
# 输出最佳方案
print("Best solution:", best_solution)
print("Time:", best_T)
print("Row boundaries:", best_row_boundaries)
print("Col boundaries:", best_col_boundaries)
print("Car Paths:", car_paths)
output_data = {
'row_boundaries': best_row_boundaries,
'col_boundaries': best_col_boundaries,
'car_paths': car_paths
}
with open(f'./solutions/trav_ga_{params_file}_parallel.json', 'w', encoding='utf-8') as file:
json.dump(output_data, file, ensure_ascii=False, indent=4)

58
test_new_GA/utils.py Normal file
View File

@ -0,0 +1,58 @@
import numpy as np
def if_valid_partition(row_boundaries, col_boundaries, params):
H = params['H']
W = params['W']
k = params['num_cars']
flight_time_factor = params['flight_time_factor']
comp_time_factor = params['comp_time_factor']
trans_time_factor = params['trans_time_factor']
car_time_factor = params['car_time_factor']
bs_time_factor = params['bs_time_factor']
flight_energy_factor = params['flight_energy_factor']
comp_energy_factor = params['comp_energy_factor']
trans_energy_factor = params['trans_energy_factor']
battery_energy_capacity = params['battery_energy_capacity']
# 根据分割边界生成所有矩形任务
rectangles = []
for i in range(len(row_boundaries) - 1):
for j in range(len(col_boundaries) - 1):
r1 = row_boundaries[i]
r2 = row_boundaries[i + 1]
c1 = col_boundaries[j]
c2 = col_boundaries[j + 1]
d = (r2 - r1) * H * (c2 - c1) * W # 任务的照片数量(矩形面积)
# 求解rho
rho_time_limit = (flight_time_factor - trans_time_factor) / \
(comp_time_factor - trans_time_factor)
rho_energy_limit = (battery_energy_capacity - flight_energy_factor * d - trans_energy_factor * d) / \
(comp_energy_factor * d - trans_energy_factor * d)
if rho_energy_limit < 0:
return []
rho = min(rho_time_limit, rho_energy_limit)
flight_time = flight_time_factor * d
comp_time = comp_time_factor * rho * d
trans_time = trans_time_factor * (1 - rho) * d
bs_time = bs_time_factor * (1 - rho) * d
# 计算任务矩形中心,用于后续车辆移动时间计算
center_r = (r1 + r2) / 2.0 * H
center_c = (c1 + c2) / 2.0 * W
rectangles.append({
# 'r1': r1, 'r2': r2, 'c1': c1, 'c2': c2,
'd': d,
'rho': rho,
'flight_time': flight_time,
'comp_time': comp_time,
'trans_time': trans_time,
'bs_time': bs_time,
'center': (center_r, center_c)
})
return rectangles

View File

@ -14,8 +14,8 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r
# 设置英文标题和标签 # 设置英文标题和标签
# ax.set_title("Monte Carlo", fontsize=12) # ax.set_title("Monte Carlo", fontsize=12)
# ax.set_title("Greedy", fontsize=12) ax.set_title("Greedy", fontsize=12)
ax.set_title("Enumeration-Genetic Algorithm", fontsize=12) # ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
# ax.set_title("DQN fine-tuning", fontsize=12) # ax.set_title("DQN fine-tuning", fontsize=12)
ax.set_xlabel("Region Width", fontsize=10) ax.set_xlabel("Region Width", fontsize=10)
@ -199,8 +199,8 @@ if __name__ == "__main__":
# --------------------------- # ---------------------------
# 需要修改的超参数 # 需要修改的超参数
# --------------------------- # ---------------------------
params_file = 'params_50_50_3' params_file = 'params_100_100_6'
solution_file = r'solutions\dqn_params_50_50_3_2.json' solution_file = r'solutions\greedy_params_100_100_6.json'
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file) params = yaml.safe_load(file)