HPCC2025/visualization.py
weixin_46229132 981681c1bd 修改dqn bug
2025-04-01 20:45:13 +08:00

234 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import json
import math
def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, rho_list,
                       title="Enumeration-Genetic Algorithm"):
    """Plot the region partition, each system's route and the per-cell rho values.

    Args:
        row_boundaries: Row split positions as fractions of ``H``. The first and
            last entries are treated as the outer edges and are not drawn.
        col_boundaries: Column split positions as fractions of ``W``, same
            convention as ``row_boundaries``.
        car_paths_coords: One sequence of ``(row, col)`` points per system, in
            visiting order. Every route is drawn starting and ending at the
            region centre.
        W: Region width (x axis extent).
        H: Region height (y axis extent; the y axis is inverted so the origin
            is at the top-left).
        rho_list: One value per grid cell, in row-major order: cell
            ``(i, j)`` uses ``rho_list[i * (len(col_boundaries) - 1) + j]``.
        title: Axes title, e.g. "Monte Carlo", "Greedy", "DQN fine-tuning".
            Defaults to "Enumeration-Genetic Algorithm".

    Returns:
        The ``(fig, ax)`` pair, so callers can save or further customise the
        figure.
    """
    # Region centre in (row, col) order, matching the (row, col) path points.
    region_center = (H / 2.0, W / 2.0)

    # Fixed-size square figure; the y axis is inverted so (0, 0) is top-left.
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(0, W)
    ax.set_ylim(H, 0)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel("Region Width", fontsize=10)
    ax.set_ylabel("Region Height", fontsize=10)

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

    # Interior partition lines only; the outer boundaries coincide with the axes.
    for row in row_boundaries[1:-1]:
        ax.axhline(y=row * H, color='gray', linestyle='--', alpha=0.5)
    for col in col_boundaries[1:-1]:
        ax.axvline(x=col * W, color='gray', linestyle='--', alpha=0.5)

    # Visit-order labels are lifted slightly above their points
    # (y is inverted, so subtracting moves them up on screen).
    label_offset = H * 0.02
    for system_id, path in enumerate(car_paths_coords):
        color = colors[system_id % len(colors)]
        # Every route leaves from and returns to the region centre.
        full_path = [region_center] + list(path) + [region_center]

        # One arrow per leg of the route.
        for (r0, c0), (r1, c1) in zip(full_path, full_path[1:]):
            ax.annotate('',
                        xy=(c1, r1),
                        xytext=(c0, r0),
                        arrowprops=dict(arrowstyle='->',
                                        color=color,
                                        lw=2,
                                        mutation_scale=15),
                        zorder=1)

        ys, xs = zip(*full_path)
        ax.plot(xs, ys, 'o', markersize=6,
                color=color,
                label=f"System {system_id}",
                zorder=2)

        # Number the visited cells, skipping the centre at both ends.
        for idx, (px, py) in enumerate(zip(xs[1:-1], ys[1:-1])):
            ax.text(px, py - label_offset, str(idx),
                    color='black',
                    fontsize=9,
                    ha='center',
                    va='bottom',
                    bbox=dict(facecolor='none',
                              edgecolor='none',
                              alpha=0.7,
                              pad=0.5))

    # Region centre marker on top of everything else.
    ax.plot(region_center[1], region_center[0],
            'k*', markersize=12, label="Region Center",
            zorder=3)

    ax.legend(loc='upper right', fontsize=9)
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, linestyle=':', alpha=0.3)

    # Annotate each cell's rho value near its top-left corner, nudged inwards
    # so the label does not sit on the partition line.
    n_cols = len(col_boundaries) - 1
    offset_x = W * 0.02
    offset_y = H * 0.02
    for i in range(len(row_boundaries) - 1):
        for j in range(n_cols):
            ax.text(col_boundaries[j] * W + offset_x,
                    row_boundaries[i] * H + offset_y,
                    f'ρ={rho_list[i * n_cols + j]:.2f}',
                    color='black',
                    fontsize=8,
                    ha='left',
                    va='top',
                    bbox=dict(facecolor='white',
                              edgecolor='none',
                              alpha=0.7,
                              pad=0.5),
                    zorder=2)

    # Lay out only after every artist has been added.
    plt.tight_layout()
    plt.show()
    return fig, ax
def _rho_limit(numerator, denominator):
    """Return ``numerator / denominator``, handling a zero denominator explicitly.

    Both rho limits in ``restore_from_solution`` are computed as such a ratio.
    Reading the formulas, each appears to solve a linear constraint of the form
    ``rho * denominator <= numerator`` (assumption; confirm against the model).
    When ``denominator`` is zero, rho drops out of that constraint. This happens
    for a zero-area cell or for equal comp/trans factors. Instead of raising
    ``ZeroDivisionError``, the function returns:

    * ``+inf`` when ``numerator >= 0``, so this limit never wins the ``min``;
    * ``-inf`` otherwise.
    """
    if denominator == 0:
        return math.inf if numerator >= 0 else -math.inf
    return numerator / denominator


def restore_from_solution(row_boundaries, col_boundaries, car_paths, params):
    """Recompute per-cell metrics for a saved solution and report each system's time.

    Cells are enumerated row-major over the grid defined by ``row_boundaries``
    and ``col_boundaries``, both given as fractions of ``H`` and ``W`` and
    presumably including the 0 and 1 outer edges. ``car_paths[i]`` lists, in
    order, the cell indices visited by system ``i``. An empty list means the
    system visits no cell and its time is 0.

    Args:
        row_boundaries: Row split fractions of ``params['H']``.
        col_boundaries: Column split fractions of ``params['W']``.
        car_paths: At least ``params['num_cars']`` lists of cell indices.
        params: Dict with keys 'H', 'W', 'num_cars', 'flight_time_factor',
            'comp_time_factor', 'trans_time_factor', 'car_time_factor',
            'bs_time_factor', 'flight_energy_factor', 'comp_energy_factor',
            'trans_energy_factor' and 'battery_energy_capacity'.

    Returns:
        ``(rectangles, rho_list)``:

        * ``rectangles`` is a list of dicts with keys 'rho', 'flight_time',
          'bs_time' and 'center'. 'center' is the cell centre as
          ``(row, col)`` in absolute units.
        * ``rho_list`` holds the 'rho' values in the same order.

    Side effects:
        Prints each system's total time and the maximum over all systems.

    Raises:
        KeyError: If a required key is missing from ``params``.
        IndexError: If ``car_paths`` has fewer than ``num_cars`` entries, or if
            a path references a non-existent cell.
    """
    H = params['H']
    W = params['W']
    k = params['num_cars']
    flight_time_factor = params['flight_time_factor']
    comp_time_factor = params['comp_time_factor']
    trans_time_factor = params['trans_time_factor']
    car_time_factor = params['car_time_factor']
    bs_time_factor = params['bs_time_factor']
    flight_energy_factor = params['flight_energy_factor']
    comp_energy_factor = params['comp_energy_factor']
    trans_energy_factor = params['trans_energy_factor']
    battery_energy_capacity = params['battery_energy_capacity']

    # The time limit does not depend on the cell size d, so compute it once.
    rho_time_limit = _rho_limit(flight_time_factor - trans_time_factor,
                                comp_time_factor - trans_time_factor)

    rectangles = []
    for i in range(len(row_boundaries) - 1):
        r1, r2 = row_boundaries[i], row_boundaries[i + 1]
        for j in range(len(col_boundaries) - 1):
            c1, c2 = col_boundaries[j], col_boundaries[j + 1]
            # Number of photos in the task, i.e. the cell area.
            d = (r2 - r1) * H * (c2 - c1) * W

            rho_energy_limit = _rho_limit(
                battery_energy_capacity - flight_energy_factor * d - trans_energy_factor * d,
                comp_energy_factor * d - trans_energy_factor * d)
            rho = min(rho_time_limit, rho_energy_limit)

            rectangles.append({
                'rho': rho,
                'flight_time': flight_time_factor * d,
                'bs_time': bs_time_factor * (1 - rho) * d,
                # Cell centre, used for the vehicle travel-time computation below.
                'center': ((r1 + r2) / 2.0 * H, (c1 + c2) / 2.0 * W),
            })

    # Every vehicle route starts and ends at the centre of the whole region.
    hub = (H / 2, W / 2)
    system_times = []
    for car_idx in range(k):
        car_path = car_paths[car_idx]
        flight_time = sum(rectangles[point]['flight_time'] for point in car_path)
        bs_time = sum(rectangles[point]['bs_time'] for point in car_path)

        car_time = 0
        if car_path:  # an idle system never leaves the hub
            centers = [rectangles[point]['center'] for point in car_path]
            # Cell-to-cell legs first, then the outbound and return legs
            # (kept in this order so float sums match earlier results).
            for start, end in zip(centers, centers[1:]):
                car_time += math.dist(start, end) * car_time_factor
            car_time += math.dist(centers[0], hub) * car_time_factor
            car_time += math.dist(centers[-1], hub) * car_time_factor

        system_time = max(flight_time + car_time, bs_time)
        system_times.append(system_time)
        print(f"系统{car_idx}的总时间: {system_time}")
    print(f"最终时间: {max(system_times, default=0)}")

    rho_list = [rectangle['rho'] for rectangle in rectangles]
    return rectangles, rho_list
if __name__ == "__main__":
    from pathlib import Path

    import yaml

    # ---------------------------
    # Hyperparameters to edit
    # ---------------------------
    params_file = 'params_50_50_3'
    # Build the path portably: a raw 'solutions\...' string only works as a
    # directory separator on Windows.
    solution_file = Path('solutions') / 'finetune_params_50_50_3.json'

    with open(params_file + '.yml', 'r', encoding='utf-8') as file:
        params = yaml.safe_load(file)
    H = params['H']
    W = params['W']
    k = params['num_cars']

    # Load the saved best solution.
    with open(solution_file, 'r', encoding='utf-8') as f:
        best_solution = json.load(f)
    row_boundaries = best_solution['row_boundaries']
    col_boundaries = best_solution['col_boundaries']
    car_paths = best_solution['car_paths']

    rectangles, rho_list = restore_from_solution(
        row_boundaries, col_boundaries, car_paths, params)

    # Replace each cell index in car_paths with that cell's centre coordinates.
    rectangles_centers = [rectangle['center'] for rectangle in rectangles]
    car_paths_coords = [[rectangles_centers[point] for point in car_paths[car_idx]]
                        for car_idx in range(k)]

    visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, rho_list)