修改GA bug
This commit is contained in:
parent
17acfa5409
commit
8e8d9a25df
76
GA/ga.py
76
GA/ga.py
@ -4,8 +4,6 @@ import random
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import plot_util
|
||||
|
||||
|
||||
class GA(object):
|
||||
def __init__(self, num_drones, num_city, num_total, data, to_process_idx, rectangles):
|
||||
@ -30,7 +28,6 @@ class GA(object):
|
||||
init_best = self.fruits[sort_index[0]]
|
||||
init_best = self.location[init_best]
|
||||
|
||||
|
||||
# 存储每个iteration的结果,画出收敛图
|
||||
self.iter_x = [0]
|
||||
self.iter_y = [1.0 / scores[sort_index[0]]]
|
||||
@ -164,7 +161,8 @@ class GA(object):
|
||||
if point in self.to_process_idx:
|
||||
continue
|
||||
else:
|
||||
flight_time += self.rectangles[point - 1]['flight_time'] # 注意,这里要减一!!!
|
||||
# 注意,这里要减一!!!
|
||||
flight_time += self.rectangles[point - 1]['flight_time']
|
||||
bs_time += self.rectangles[point - 1]['bs_time']
|
||||
system_time = max(flight_time + car_info['car_time'], bs_time)
|
||||
T_k_list.append(system_time)
|
||||
@ -323,47 +321,47 @@ class GA(object):
|
||||
# print(1.0 / best_score)
|
||||
return tmp_best_one, 1.0 / best_score
|
||||
|
||||
if __name__ == '__main__':
|
||||
seed = 42
|
||||
num_drones = 6
|
||||
num_city = 12
|
||||
epochs = 3000
|
||||
# if __name__ == '__main__':
|
||||
# seed = 42
|
||||
# num_drones = 6
|
||||
# num_city = 12
|
||||
# epochs = 3000
|
||||
|
||||
# 固定随机数
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
# # 固定随机数
|
||||
# np.random.seed(seed)
|
||||
# random.seed(seed)
|
||||
|
||||
|
||||
## 初始化坐标 (第一个点是基地的起点,起点的坐标是 0,0 )
|
||||
data = [[0, 0]]
|
||||
for i in range(num_city - 1):
|
||||
while True:
|
||||
x = np.random.randint(-250, 250)
|
||||
y = np.random.randint(-250, 250)
|
||||
if x != 0 or y != 0:
|
||||
break
|
||||
data.append([x, y])
|
||||
# ## 初始化坐标 (第一个点是基地的起点,起点的坐标是 0,0 )
|
||||
# data = [[0, 0]]
|
||||
# for i in range(num_city - 1):
|
||||
# while True:
|
||||
# x = np.random.randint(-250, 250)
|
||||
# y = np.random.randint(-250, 250)
|
||||
# if x != 0 or y != 0:
|
||||
# break
|
||||
# data.append([x, y])
|
||||
|
||||
data = np.array(data)
|
||||
# data = np.array(data)
|
||||
|
||||
# 关键:有N架无人机,则再增加N-1个`点` (坐标是起始点),这些点之间的距离是inf
|
||||
for d in range(num_drones - 1):
|
||||
data = np.vstack([data, data[0]])
|
||||
num_city += 1 # 增加欺骗城市
|
||||
# # 关键:有N架无人机,则再增加N-1个`点` (坐标是起始点),这些点之间的距离是inf
|
||||
# for d in range(num_drones - 1):
|
||||
# data = np.vstack([data, data[0]])
|
||||
# num_city += 1 # 增加欺骗城市
|
||||
|
||||
to_process_idx = [0]
|
||||
# print("start point:", location[0])
|
||||
for d in range(1, num_drones): # 1, ... drone-1
|
||||
# print("added base point:", location[num_city - d])
|
||||
to_process_idx.append(num_city - d)
|
||||
|
||||
model = GA(num_city=data.shape[0], num_total=20, data=data.copy())
|
||||
Best_path, Best = model.run()
|
||||
print(Best_path)
|
||||
iterations = model.iter_x
|
||||
best_record = model.iter_y
|
||||
# to_process_idx = [0]
|
||||
# # print("start point:", location[0])
|
||||
# for d in range(1, num_drones): # 1, ... drone-1
|
||||
# # print("added base point:", location[num_city - d])
|
||||
# to_process_idx.append(num_city - d)
|
||||
|
||||
# model = GA(num_city=data.shape[0], num_total=20, data=data.copy())
|
||||
# Best_path, Best = model.run()
|
||||
# print(Best_path)
|
||||
# iterations = model.iter_x
|
||||
# best_record = model.iter_y
|
||||
|
||||
print(f"Best Path Length: {Best:.3f}")
|
||||
plot_util.plot_results(Best_path, iterations, best_record)
|
||||
# # print(Best_path)
|
||||
|
||||
# print(f"Best Path Length: {Best:.3f}")
|
||||
# plot_util.plot_results(Best_path, iterations, best_record)
|
||||
|
@ -4,7 +4,6 @@ import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
import numpy as np
|
||||
from ga import GA
|
||||
import plot_util
|
||||
|
||||
# 固定随机种子,便于复现
|
||||
random.seed(42)
|
||||
|
24
GA/main.py
24
GA/main.py
@ -14,8 +14,8 @@ best_solution = None
|
||||
best_row_boundaries = None
|
||||
best_col_boundaries = None
|
||||
|
||||
params_file = 'params2.yml'
|
||||
with open(params_file, 'r', encoding='utf-8') as file:
|
||||
params_file = 'params3'
|
||||
with open(params_file + 'yml', 'r', encoding='utf-8') as file:
|
||||
params = yaml.safe_load(file)
|
||||
|
||||
H = params['H']
|
||||
@ -37,13 +37,13 @@ battery_energy_capacity = params['battery_energy_capacity']
|
||||
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
|
||||
# 生成所有的排列情况(取三次,每次都可以从10个数中选)
|
||||
row_product = list(product(numbers, repeat=3))
|
||||
row_product = list(product(numbers, repeat=1))
|
||||
# 对每种情况从小到大排序,并剔除重复的情况
|
||||
row_cuts_set = set(
|
||||
tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product)
|
||||
row_cuts_set = sorted(row_cuts_set)
|
||||
|
||||
col_product = list(product(numbers, repeat=3))
|
||||
col_product = list(product(numbers, repeat=1))
|
||||
col_cuts_set = set(
|
||||
tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product)
|
||||
col_cuts_set = sorted(col_cuts_set)
|
||||
@ -56,13 +56,15 @@ with tqdm(total=total_iterations, desc="Processing") as pbar:
|
||||
col_boundaries = [0.0] + list(col_cuts) + [1.0]
|
||||
|
||||
# 这里面的距离不再是比例,而是真实距离!
|
||||
rectrangles = if_valid_partition(row_boundaries, col_boundaries, params)
|
||||
rectrangles = if_valid_partition(
|
||||
row_boundaries, col_boundaries, params)
|
||||
if not rectrangles:
|
||||
pbar.update(1)
|
||||
continue
|
||||
else:
|
||||
# 使用遗传算法求出每一种网格划分的可行解,然后选择其中的最优解
|
||||
current_solution, current_time, to_process_idx = GA_solver(rectrangles, k)
|
||||
current_solution, current_time, to_process_idx = GA_solver(
|
||||
rectrangles, params)
|
||||
|
||||
if current_time < best_T:
|
||||
best_T = current_time
|
||||
@ -83,7 +85,9 @@ with tqdm(total=total_iterations, desc="Processing") as pbar:
|
||||
for k in range(from_index, end_index + 1):
|
||||
rectrangle_idx = best_solution[k]
|
||||
if rectrangle_idx not in to_process_idx:
|
||||
car_path.append(rectrangles[rectrangle_idx - 1]['center'])
|
||||
car_path.append(
|
||||
rectrangles[rectrangle_idx - 1]['center'])
|
||||
if car_path:
|
||||
car_paths.append(car_path)
|
||||
pbar.update(1)
|
||||
|
||||
@ -94,9 +98,9 @@ print("Row boundaries:", best_row_boundaries)
|
||||
print("Col boundaries:", best_col_boundaries)
|
||||
|
||||
output_data = {
|
||||
'row_boundaries': row_boundaries,
|
||||
'col_boundaries': col_boundaries,
|
||||
'row_boundaries': best_row_boundaries,
|
||||
'col_boundaries': best_col_boundaries,
|
||||
'car_paths': car_paths
|
||||
}
|
||||
with open(f'./solutions/traverse_ga_{params_file}.json', 'w', encoding='utf-8') as file:
|
||||
with open(f'./solutions/trav_ga_{params_file}.json', 'w', encoding='utf-8') as file:
|
||||
json.dump(output_data, file, ensure_ascii=False, indent=4)
|
@ -8,12 +8,27 @@ import json
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
|
||||
def process_partition(row_cuts, col_cuts, params):
|
||||
row_boundaries = [0.0] + list(row_cuts) + [1.0]
|
||||
col_boundaries = [0.0] + list(col_cuts) + [1.0]
|
||||
rectrangles = if_valid_partition(row_boundaries, col_boundaries, params)
|
||||
|
||||
if not rectrangles:
|
||||
return None # 过滤无效划分
|
||||
|
||||
current_solution, current_time, to_process_idx = GA_solver(
|
||||
rectrangles, params)
|
||||
|
||||
return (current_solution, current_time, row_boundaries, col_boundaries, to_process_idx, rectrangles)
|
||||
|
||||
|
||||
if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
|
||||
np.random.seed(42)
|
||||
random.seed(42)
|
||||
|
||||
|
||||
params_file = 'params2.yml'
|
||||
with open(params_file, 'r', encoding='utf-8') as file:
|
||||
params_file = 'params3'
|
||||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||||
params = yaml.safe_load(file)
|
||||
|
||||
H = params['H']
|
||||
@ -35,39 +50,25 @@ battery_energy_capacity = params['battery_energy_capacity']
|
||||
numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||
|
||||
# 生成所有的排列情况(取三次,每次都可以从10个数中选)
|
||||
row_product = list(product(numbers, repeat=3))
|
||||
row_product = list(product(numbers, repeat=1))
|
||||
# 对每种情况从小到大排序,并剔除重复的情况
|
||||
row_cuts_set = set(
|
||||
tuple(sorted(set(item for item in prod if item > 0))) for prod in row_product)
|
||||
row_cuts_set = sorted(row_cuts_set)
|
||||
|
||||
col_product = list(product(numbers, repeat=3))
|
||||
col_product = list(product(numbers, repeat=1))
|
||||
col_cuts_set = set(
|
||||
tuple(sorted(set(item for item in prod if item > 0))) for prod in col_product)
|
||||
col_cuts_set = sorted(col_cuts_set)
|
||||
|
||||
|
||||
def process_partition(row_cuts, col_cuts):
|
||||
row_boundaries = [0.0] + list(row_cuts) + [1.0]
|
||||
col_boundaries = [0.0] + list(col_cuts) + [1.0]
|
||||
rectrangles = if_valid_partition(row_boundaries, col_boundaries, params)
|
||||
|
||||
if not rectrangles:
|
||||
return None # 过滤无效划分
|
||||
|
||||
current_solution, current_time, to_process_idx = GA_solver(rectrangles, k)
|
||||
|
||||
return (current_solution, current_time, row_boundaries, col_boundaries, to_process_idx, rectrangles)
|
||||
|
||||
|
||||
if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
|
||||
best_T = float('inf')
|
||||
best_solution = None
|
||||
best_row_boundaries = None
|
||||
best_col_boundaries = None
|
||||
batch_size = 60 # 控制一次最多并行多少个任务
|
||||
|
||||
all_tasks = [(row_cuts, col_cuts) for row_cuts in row_cuts_set for col_cuts in col_cuts_set]
|
||||
all_tasks = [(row_cuts, col_cuts)
|
||||
for row_cuts in row_cuts_set for col_cuts in col_cuts_set]
|
||||
total_iterations = len(all_tasks)
|
||||
|
||||
with ProcessPoolExecutor(max_workers=batch_size) as executor:
|
||||
@ -82,7 +83,7 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
|
||||
pbar.update(1) # 更新进度条
|
||||
futures.clear() # 清空已完成的任务
|
||||
|
||||
futures.add(executor.submit(process_partition, *task)) # 提交新任务
|
||||
futures.add(executor.submit(process_partition, *task, params)) # 提交新任务
|
||||
|
||||
# 处理剩余未完成的任务
|
||||
for future in as_completed(futures):
|
||||
@ -112,7 +113,9 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
|
||||
for k in range(from_index, end_index + 1):
|
||||
rectrangle_idx = best_solution[k]
|
||||
if rectrangle_idx not in to_process_idx:
|
||||
car_path.append(rectrangles[rectrangle_idx - 1]['center'])
|
||||
car_path.append(
|
||||
rectrangles[rectrangle_idx - 1]['center'])
|
||||
if car_path:
|
||||
car_paths.append(car_path)
|
||||
|
||||
# 输出最佳方案
|
||||
@ -122,9 +125,9 @@ if __name__ == "__main__": # 重要:在 Windows 上必须加这一行
|
||||
print("Col boundaries:", best_col_boundaries)
|
||||
|
||||
output_data = {
|
||||
'row_boundaries': row_boundaries,
|
||||
'col_boundaries': col_boundaries,
|
||||
'row_boundaries': best_row_boundaries,
|
||||
'col_boundaries': best_col_boundaries,
|
||||
'car_paths': car_paths
|
||||
}
|
||||
with open(f'./solutions/travse_ga_{params_file}.json', 'w', encoding='utf-8') as file:
|
||||
with open(f'./solutions/trav_ga_{params_file}_parallel.json', 'w', encoding='utf-8') as file:
|
||||
json.dump(output_data, file, ensure_ascii=False, indent=4)
|
||||
|
@ -1,93 +0,0 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# from matplotlib import colors as mcolors
|
||||
|
||||
# matplotlib_colors = list(dict(mcolors.BASE_COLORS, **mcolors.CSS4_COLORS).keys())
|
||||
|
||||
matplotlib_colors = [
|
||||
"black",
|
||||
"red",
|
||||
"yellow",
|
||||
"grey",
|
||||
"brown",
|
||||
"darkred",
|
||||
"peru",
|
||||
"darkorange",
|
||||
"darkkhaki",
|
||||
"steelblue",
|
||||
"blue",
|
||||
"cyan",
|
||||
"green",
|
||||
"navajowhite",
|
||||
"lightgrey",
|
||||
"lightcoral",
|
||||
"mediumblue",
|
||||
"midnightblue",
|
||||
"blueviolet",
|
||||
"violet",
|
||||
"fuchsia",
|
||||
"mediumvioletred",
|
||||
"hotpink",
|
||||
"crimson",
|
||||
"lightpink",
|
||||
"slategray",
|
||||
"lime",
|
||||
"springgreen",
|
||||
"teal",
|
||||
"beige",
|
||||
"olive",
|
||||
]
|
||||
|
||||
|
||||
def find_indices(list_to_check, item_to_find):
|
||||
indices = []
|
||||
for idx, value in enumerate(list_to_check):
|
||||
if np.array_equal(value, item_to_find):
|
||||
indices.append(idx)
|
||||
return indices
|
||||
|
||||
|
||||
def plot_results(Best_path, iterations, best_record):
|
||||
# print(find_indices(Best_path, [0, 0]))
|
||||
|
||||
# Best_path = np.vstack([Best_path, Best_path[0]])
|
||||
# Best_path = np.vstack([Best_path[0], Best_path])
|
||||
# print(Best_path[0], Best_path[-1])
|
||||
|
||||
if not np.array_equal(Best_path[0], [0, 0]):
|
||||
Best_path = np.vstack([[0, 0], Best_path])
|
||||
if not np.array_equal(Best_path[-1], [0, 0]):
|
||||
Best_path = np.vstack([Best_path, [0, 0]])
|
||||
# print(Best_path)
|
||||
|
||||
found_start_points_indices = find_indices(Best_path, [0, 0])
|
||||
result_paths = []
|
||||
|
||||
for j in range(len(found_start_points_indices) - 1):
|
||||
from_index = found_start_points_indices[j]
|
||||
end_index = found_start_points_indices[j + 1]
|
||||
path = []
|
||||
for k in range(from_index, end_index + 1):
|
||||
path.append(Best_path[k])
|
||||
path = np.array(path)
|
||||
result_paths.append(path)
|
||||
|
||||
# print(Best_path)
|
||||
# print(result_paths)
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=False, sharey=False)
|
||||
axs[0].scatter(Best_path[:, 0], Best_path[:, 1])
|
||||
|
||||
for ix, path in enumerate(result_paths):
|
||||
axs[0].plot(path[:, 0], path[:, 1], color=matplotlib_colors[ix], alpha=0.8)
|
||||
# axs[0].plot(Best_path[:, 0], Best_path[:, 1], color="green", alpha=0.1)
|
||||
|
||||
# Draw start point
|
||||
axs[0].plot([0], [0], marker="*", markersize=20, color="red")
|
||||
|
||||
axs[0].set_title("Searched Best Solution")
|
||||
|
||||
axs[1].plot(iterations, best_record)
|
||||
axs[1].set_title("Convergence Curve")
|
||||
plt.show()
|
1
GA/readme.md
Normal file
1
GA/readme.md
Normal file
@ -0,0 +1 @@
|
||||
参考代码`https://github.com/yhzhu99/mtsp`
|
@ -59,11 +59,12 @@ def if_valid_partition(row_boundaries, col_boundaries, params):
|
||||
return rectangles
|
||||
|
||||
|
||||
def GA_solver(rectangles, k):
|
||||
def GA_solver(rectangles, params):
|
||||
num_city = len(rectangles) + 1 # 划分好的区域中心点+整个区域的中心
|
||||
k = params['num_cars']
|
||||
|
||||
# 初始化坐标 (第一个点是整个区域的中心)
|
||||
center_data = [[1 / 2.0, 1 / 2.0]]
|
||||
center_data = [[params['H'] / 2.0, params['W'] / 2.0]]
|
||||
for rec in rectangles:
|
||||
center_data.append(rec['center'])
|
||||
center_data = np.array(center_data)
|
||||
@ -79,7 +80,7 @@ def GA_solver(rectangles, k):
|
||||
# print("added base point:", location[num_city - d])
|
||||
to_process_idx.append(num_city - d)
|
||||
|
||||
model = GA(num_drones=k, num_city=center_data.shape[0], num_total=20, data=center_data.copy(
|
||||
model = GA(num_drones=k, num_city=num_city, num_total=20, data=center_data.copy(
|
||||
), to_process_idx=to_process_idx, rectangles=rectangles)
|
||||
Best_path, Best = model.run()
|
||||
|
||||
|
16
params3.yml
Normal file
16
params3.yml
Normal file
@ -0,0 +1,16 @@
|
||||
H : 30 # 区域高度,网格点之间的距离为25m(单位距离)
|
||||
W : 30 # 区域宽度
|
||||
num_cars : 2 # 系统数量(车-巢-机系统个数)
|
||||
|
||||
# 时间系数(单位:秒,每个网格一张照片)
|
||||
flight_time_factor : 3 # 每张照片对应的飞行时间,无人机飞行速度为9.5m/s,拍摄照片的时间间隔为3s
|
||||
comp_time_factor : 5 # 无人机上每张照片计算时间,5s
|
||||
trans_time_factor : 0.3 # 每张照片传输时间,0.3s
|
||||
car_time_factor : 100 # TODO 汽车每单位距离的移动时间,2s,加了一个放大因子50
|
||||
bs_time_factor : 5 # 机巢上每张照片计算时间
|
||||
|
||||
# 其他参数
|
||||
flight_energy_factor : 0.05 # 单位:分钟/张
|
||||
comp_energy_factor : 0.05 # TODO 计算能耗需要重新估计
|
||||
trans_energy_factor : 0.0025
|
||||
battery_energy_capacity : 20 # 无人机只进行飞行,续航为30分钟
|
@ -7,7 +7,7 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths, W, H):
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_xlim(0, W)
|
||||
ax.set_ylim(0, H)
|
||||
ax.set_ylim(H, 0) # 调整y轴方向,原点在左上角
|
||||
ax.set_title("区域划分与车-机-巢系统覆盖")
|
||||
ax.set_xlabel("区域宽度")
|
||||
ax.set_ylabel("区域高度")
|
||||
@ -16,7 +16,7 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths, W, H):
|
||||
colors = ['red', 'blue', 'green', 'orange', 'purple', 'cyan', 'magenta']
|
||||
|
||||
# 绘制区域中心
|
||||
region_center = (H / 2.0, W / 2.0) # 注意:x对应宽度,y对应高度
|
||||
region_center = (H / 2.0, W / 2.0)
|
||||
ax.plot(region_center[1], region_center[0],
|
||||
'ko', markersize=8, label="区域中心")
|
||||
|
||||
@ -36,22 +36,20 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths, W, H):
|
||||
|
||||
# 添加图例
|
||||
ax.legend()
|
||||
# 反转 y 轴使得行号从上到下递增(如需,可取消)
|
||||
ax.invert_yaxis()
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import yaml
|
||||
|
||||
# 读取参数
|
||||
with open('params.yml', 'r', encoding='utf-8') as file:
|
||||
with open('params3.yml', 'r', encoding='utf-8') as file:
|
||||
params = yaml.safe_load(file)
|
||||
|
||||
H = params['H']
|
||||
W = params['W']
|
||||
|
||||
# 读取最佳方案的JSON文件
|
||||
with open('./solutions/best_solution_mtkl.json', 'r', encoding='utf-8') as f:
|
||||
with open(r'solutions\trav_ga_params3_parallel.json', 'r', encoding='utf-8') as f:
|
||||
best_solution = json.load(f)
|
||||
|
||||
row_boundaries = best_solution['row_boundaries']
|
||||
|
Loading…
Reference in New Issue
Block a user