153 lines
5.4 KiB
Python
153 lines
5.4 KiB
Python
import numpy as np
|
|
import gym
|
|
from gym import spaces
|
|
|
|
class RoutingEnv(gym.Env):
    """Routing / path-planning environment (third layer of the hierarchy).

    A ground vehicle carrying a UAV visits a set of task regions. Each step
    selects the index of the next unvisited task; the episode ends when every
    task has been visited, or immediately (with a large penalty) when an
    already-visited or infeasible task is chosen.

    Observation: [current x, current y, unvisited-task mask] as float32.
    Action: Discrete(len(tasks)) — index of the next task to visit.
    """

    def __init__(self, tasks):
        """Build the environment.

        Args:
            tasks: list of task dicts; each task must provide 'center'
                (an (x, y) position) and 'area' (region size).
        """
        super(RoutingEnv, self).__init__()

        self.tasks = tasks          # task list
        self.H = 20                 # region height
        self.W = 25                 # region width
        self.region_center = (self.H / 2, self.W / 2)

        # Time coefficients of the system model.
        self.flight_time_factor = 3      # flight time per photo
        self.comp_uav_factor = 5         # UAV on-board computation time
        self.trans_time_factor = 0.3     # transmission time
        self.car_move_time_factor = 100  # ground-vehicle movement time
        self.comp_bs_factor = 5          # base-station (nest) computation time

        # Action space: index of the next task to visit.
        self.action_space = spaces.Discrete(len(tasks))

        # Observation space: [x, y, unvisited mask] — declared float32, so
        # every observation we emit is cast to float32 to match.
        self.observation_space = spaces.Box(
            low=np.array([0, 0] + [0] * len(tasks), dtype=np.float32),
            high=np.array([self.H, self.W] + [1] * len(tasks), dtype=np.float32),
            dtype=np.float32,
        )

        # Initialize episode state right away so step() is well-defined even
        # if a caller forgets to call reset() first (self.state used to stay
        # None until the first reset).
        self.reset()

    def calculate_task_time(self, task):
        """Compute the execution time of a single task.

        A fraction (1 - rho) of the computation is offloaded to the base
        station; rho is the largest offloading ratio feasible under both the
        time and the energy constraint.

        Args:
            task: dict with at least an 'area' entry.

        Returns:
            (task_time, rho) on success, or (None, None) when the energy
            constraint cannot be met (task infeasible).
        """
        area = task['area']

        # Upper bounds on rho from the time and energy constraints.
        rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
                         (self.comp_uav_factor - self.trans_time_factor)
        # 30 is presumably the UAV energy budget — TODO confirm against the
        # system model; a negative value means infeasible even at rho = 0.
        rho_energy_limit = (30 - self.flight_time_factor * area - self.trans_time_factor * area) / \
                           (self.comp_uav_factor * area - self.trans_time_factor * area)
        if rho_energy_limit < 0:
            return None, None
        rho = min(rho_time_limit, rho_energy_limit)

        # Flight and base-station computation dominate the makespan; the
        # on-board compute and transmission stages overlap with flight and
        # do not enter the max below.
        flight_time = self.flight_time_factor * area
        comp_bs_time = self.comp_bs_factor * (1 - rho) * area

        task_time = max(flight_time, comp_bs_time)
        return task_time, rho

    def calculate_move_time(self, from_pos, to_pos):
        """Return the vehicle travel time between two (x, y) positions.

        Euclidean distance scaled by the vehicle movement-time coefficient.
        """
        dist = np.sqrt((from_pos[0] - to_pos[0]) ** 2 + (from_pos[1] - to_pos[1]) ** 2)
        return dist * self.car_move_time_factor

    def step(self, action):
        """Visit task `action` and return (state, reward, done, info).

        Selecting an already-visited task or an infeasible task terminates
        the episode with a -10000 penalty.
        """
        # Reject re-visiting a task: heavy penalty, episode over.
        if self.unvisited_mask[action] == 0:
            return self.state, -10000, True, {}

        task = self.tasks[action]
        task_center = task['center']

        # Vehicle travel time for this leg.
        move_time = self.calculate_move_time(self.current_position, task_center)

        # Task execution time; None marks an infeasible task.
        task_time, rho = self.calculate_task_time(task)
        if task_time is None:
            return self.state, -10000, True, {}

        # Commit the transition.
        self.current_position = task_center
        self.unvisited_mask[action] = 0
        self.total_flight_time += task_time

        # Rebuild the observation; cast to float32 to match observation_space.
        self.state = np.concatenate([
            np.array(self.current_position),
            self.unvisited_mask
        ]).astype(np.float32)

        # Episode ends once every task has been visited.
        done = np.sum(self.unvisited_mask) == 0

        # Reward is the negative time cost.
        # NOTE(review): the terminal reward uses only the LAST leg's
        # move_time, not an accumulated travel time — preserved as written,
        # but worth confirming against the intended system model.
        total_time = max(self.total_flight_time, move_time)
        reward = -total_time if done else -move_time

        return self.state, reward, done, {}

    def reset(self):
        """Start a new episode: vehicle at the region center, all tasks unvisited."""
        self.current_position = self.region_center
        self.unvisited_mask = np.ones(len(self.tasks), dtype=np.float32)
        self.total_flight_time = 0

        self.state = np.concatenate([
            np.array(self.current_position),
            self.unvisited_mask
        ]).astype(np.float32)
        return self.state

    def render(self, mode='human'):
        """No visualization is implemented for this environment."""
        pass

    def optimize(self):
        """Train a small DQN to find a visiting order for the tasks.

        Returns:
            (best_time, valid_solution): the smallest completion time seen
            across episodes (float('inf') if none), and whether any episode
            terminated with a feasible route.
        """
        from dqn import Agent  # project-local DQN agent; imported lazily

        state_dim = self.observation_space.shape[0]
        action_dim = len(self.tasks)

        agent = Agent(state_dim, action_dim)

        # Keep training short — this is the innermost sub-problem.
        episodes = 50
        max_steps = len(self.tasks) + 1  # visit all tasks, plus the return leg

        best_time = float('inf')
        valid_solution = False

        for _ in range(episodes):
            state = self.reset()

            for _ in range(max_steps):
                action = agent.choose_action(state)
                next_state, reward, done, _ = self.step(action)

                agent.store_transition(state, action, reward, next_state, done)
                agent.learn()

                state = next_state

                if done:
                    if reward != -10000:  # feasible terminal state
                        valid_solution = True
                        best_time = min(best_time, -reward)
                    break

        return best_time, valid_solution