107 lines
3.4 KiB
Python
107 lines
3.4 KiB
Python
import math
|
||
import yaml
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.patches as patches
|
||
import numpy as np
|
||
import json
|
||
|
||
|
||
def calculate_max_photos_per_flight(params):
|
||
"""计算每次飞行能拍摄的最大照片数量
|
||
基于以下约束:
|
||
1. 电池能量约束
|
||
2. 计算+传输时间 = 飞行时间
|
||
"""
|
||
# 从参数中提取时间和能量因子
|
||
flight_time_factor = params['flight_time_factor']
|
||
comp_time_factor = params['comp_time_factor']
|
||
trans_time_factor = params['trans_time_factor']
|
||
battery_energy_capacity = params['battery_energy_capacity']
|
||
flight_energy_factor = params['flight_energy_factor']
|
||
comp_energy_factor = params['comp_energy_factor']
|
||
trans_energy_factor = params['trans_energy_factor']
|
||
|
||
# 基于时间约束求解rho:飞行时间 = 计算时间 + 传输时间
|
||
# flight_time_factor * d = comp_time_factor * rho * d + trans_time_factor * (1-rho) * d
|
||
rho_time = (flight_time_factor - trans_time_factor) / \
|
||
(comp_time_factor - trans_time_factor)
|
||
|
||
# 基于能量约束求解最大照片数d
|
||
# battery_energy_capacity = flight_energy_factor * d + comp_energy_factor * rho * d + trans_energy_factor * (1-rho) * d
|
||
energy_per_photo = (flight_energy_factor +
|
||
comp_energy_factor * rho_time +
|
||
trans_energy_factor * (1 - rho_time))
|
||
|
||
max_photos = battery_energy_capacity / energy_per_photo
|
||
|
||
return max_photos
|
||
|
||
|
||
def solve_greedy(params):
|
||
"""使用贪心算法求解任务分配问题"""
|
||
H = params['H']
|
||
W = params['W']
|
||
k = params['num_cars'] # 车辆数量
|
||
|
||
# 1. 首先将区域均匀切分成k行
|
||
row_ratio = 1 / k
|
||
row_boundaries = [i * row_ratio for i in range(k + 1)]
|
||
|
||
# 2. 计算每次飞行能拍摄的最大照片数量
|
||
photos_per_flight = calculate_max_photos_per_flight(params)
|
||
print(f"每次飞行能拍摄的最大照片数: {photos_per_flight}")
|
||
|
||
# 3. 针对每一行计算网格划分
|
||
row_start = row_boundaries[0]
|
||
row_end = row_boundaries[1]
|
||
row_height = (row_end - row_start) * H
|
||
|
||
# 计算每个网格的宽度
|
||
# 网格面积 = row_height * grid_width = photos_per_flight
|
||
grid_width = photos_per_flight / row_height
|
||
col_ratio = grid_width / W
|
||
|
||
col_boundaries = [0]
|
||
ratio = 0
|
||
while (ratio + col_ratio) < 1:
|
||
ratio += col_ratio
|
||
col_boundaries.append(ratio)
|
||
col_boundaries.append(1)
|
||
|
||
car_paths = []
|
||
for i in range(k):
|
||
car_path = list(range(i * (len(col_boundaries) - 1),
|
||
(i+1) * (len(col_boundaries) - 1)))
|
||
car_paths.append(car_path)
|
||
|
||
return row_boundaries, col_boundaries, car_paths
|
||
|
||
|
||
def main():
|
||
# ---------------------------
|
||
# 需要修改的超参数
|
||
# ---------------------------
|
||
params_file = 'params_100_100_6'
|
||
|
||
# 读取参数
|
||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||
params = yaml.safe_load(file)
|
||
|
||
# 求解
|
||
row_boundaries, col_boundaries, car_paths = solve_greedy(params)
|
||
|
||
# ---------------------------
|
||
# 输出最佳方案
|
||
# ---------------------------
|
||
output_data = {
|
||
'row_boundaries': row_boundaries,
|
||
'col_boundaries': col_boundaries,
|
||
'car_paths': car_paths
|
||
}
|
||
with open(f'./solutions/greedy_{params_file}.json', 'w', encoding='utf-8') as f:
|
||
json.dump(output_data, f, ensure_ascii=False, indent=4)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|