Add Q-learning TSP
This commit is contained in:
parent 1485fb2bd6
commit a375832b6c
@@ -1,22 +1,22 @@
%matplotlib inline
import pylab as plt
from IPython.display import clear_output
import numpy as np
import asyncio


class TSP(object):
    '''
    Solve the TSP (traveling salesman problem) with Q-Learning
    Author: Surfer Zen @ https://www.zhihu.com/people/surfer-zen
    '''

    def __init__(self,
                 num_cities=15,
                 map_size=(800.0, 600.0),
                 alpha=2,
                 beta=1,
                 learning_rate=0.001,
                 eps=0.1,
                 ):
        '''
        Args:
            num_cities (int): number of cities
@@ -26,7 +26,7 @@ class TSP(object):
            learning_rate (float): learning rate
            eps (float): exploration rate; larger values mean more exploration but harder convergence
        '''
        self.num_cities = num_cities
        self.map_size = map_size
        self.alpha = alpha
        self.beta = beta
@@ -40,7 +40,6 @@
        self.best_path = None
        self.best_path_length = np.inf

    def generate_cities(self):
        '''
        Randomly generate the cities (coordinates)
@@ -77,7 +76,8 @@
        current_city_id = start_city_id
        cities_visited.append(current_city_id)
        while len(cities_visited) < self.num_cities:
            current_city_id, action_prob = self.choose_next_city(
                cities_visited)
            cities_visited.append(current_city_id)
            action_probs.append(action_prob)
        cities_visited.append(cities_visited[0])
@@ -95,7 +95,7 @@
        Choose the next city according to the policy
        '''
        current_city_id = cities_visited[-1]

        # exponentiate the qualities, used to form the softmax probabilities
        probabilities = np.exp(self.qualities[current_city_id])

@@ -105,10 +105,11 @@

        # compute the softmax probabilities
        probabilities = probabilities/probabilities.sum()

        if np.random.random() < self.eps:
            # with probability eps, sample from the softmax distribution
            next_city_id = np.random.choice(
                range(len(probabilities)), p=probabilities)
        else:
            # with probability (1 - eps), follow the current greedy policy
            next_city_id = probabilities.argmax()
@@ -118,7 +119,7 @@
            action_prob = probabilities[next_city_id]*self.eps + (1-self.eps)
        else:
            action_prob = probabilities[next_city_id]*self.eps

        return next_city_id, action_prob

    def calc_path_rewards(self, path, path_length):
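The action_prob computed here is the probability of the sampled city under the mixed policy: with probability eps sample from the softmax distribution, otherwise take the argmax. A minimal standalone sketch, with an illustrative function name and example values that are not taken from the repository:

import numpy as np

def eps_softmax_prob(probabilities, chosen, eps):
    """Probability of `chosen` when we sample from `probabilities` with
    probability eps and otherwise take the argmax (mirrors action_prob above)."""
    prob = eps * probabilities[chosen]
    if chosen == probabilities.argmax():
        prob += 1 - eps  # the greedy branch also lands on this city
    return prob

p = np.array([0.1, 0.6, 0.3])
print(eps_softmax_prob(p, chosen=1, eps=0.1))  # 0.96 = 0.1*0.6 + 0.9
print(eps_softmax_prob(p, chosen=2, eps=0.1))  # 0.03 = 0.1*0.3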
@@ -146,7 +147,7 @@
        for fr, to in zip(path[:-1], path[1:]):
            path_length += self.distances[fr, to]
        return path_length

    def calc_updates_for_one_rollout(self, path, action_probs, rewards):
        '''
        Given the result of one rollout, compute the corresponding qualities and normalizers
@@ -165,16 +166,19 @@
        '''
        lr = self.learning_rate
        for fr, to, new_quality, new_normalizer in zip(
                path[:-1], path[1:], new_qualities, new_normalizers):
            self.normalizers[fr] = (
                1-lr)*self.normalizers[fr] + lr*new_normalizer
            self.qualities[fr, to] = (
                1-lr)*self.qualities[fr, to] + lr*new_quality

    async def train_for_one_rollout(self, start_city_id):
        '''
        Training procedure for one rollout
        '''
        path, action_probs, rewards = self.rollout(start_city_id=start_city_id)
        new_qualities, new_normalizers = self.calc_updates_for_one_rollout(
            path, action_probs, rewards)
        self.update(path, new_qualities, new_normalizers)

    async def train_for_one_epoch(self):
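The update above is an exponential moving average: each stored quality and normalizer is nudged toward the estimate from the latest rollout by a factor of the learning rate. A tiny numeric sketch with arbitrary values:

lr = 0.001
old_quality, new_quality = 2.0, 5.0
updated = (1 - lr) * old_quality + lr * new_quality
print(updated)  # 2.003: each rollout moves the stored value slightly toward the new estimate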
@@ -182,7 +186,8 @@
        Training procedure for one epoch;
        one epoch performs one rollout starting from each city
        '''
        tasks = [self.train_for_one_rollout(
            start_city_id) for start_city_id in range(self.num_cities)]
        await asyncio.gather(*tasks)

    async def train(self, num_epochs=1000, display=True):
@@ -203,14 +208,30 @@
                x1, y1 = self.cities[:, fr]
                x2, y2 = self.cities[:, to]
                dx, dy = x2-x1, y2-y1
                plt.arrow(x1, y1, dx, dy, width=0.01*min(self.map_size),
                          edgecolor='orange', facecolor='white', animated=True,
                          length_includes_head=True)
            nrs = np.exp(self.qualities)
            for i in range(self.num_cities):
                nrs[i, i] = 0
            gap = np.abs(np.exp(self.normalizers) - nrs.sum(-1)).mean()
            plt.title(
                f'epoch {epoch}: path length = {self.best_path_length:.2f}, normalizer error = {gap:.3f}')
            plt.savefig('tsp.png')
            plt.show()
            clear_output(wait=True)


async def main():
    # create a TSP instance
    tsp = TSP()

    # train the model
    await tsp.train(200, display=False)

    # print the final path
    print(f"Best path: {tsp.best_path}")
    print(f"Path length: {tsp.best_path_length:.2f}")


if __name__ == '__main__':
    # run the async main function with asyncio.run()
    asyncio.run(main())
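Note that asyncio.run() cannot be called from an already-running event loop, so the `if __name__ == '__main__'` block only applies when the file runs as a script (and the `%matplotlib inline` magic at the top would also need to be removed or guarded for that). In the IPython/Jupyter environment implied by `%matplotlib inline` and `clear_output`, the coroutine can be awaited directly. A sketch of such a cell, assuming the class above is already defined:

# In a Jupyter/IPython cell the event loop is already running,
# so await the coroutine directly instead of calling asyncio.run():
tsp = TSP()
await tsp.train(num_epochs=200, display=True)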
278  Q_learning/mTSP.py  Normal file
@@ -0,0 +1,278 @@
import pylab as plt
import numpy as np
import asyncio
from typing import List, Tuple, Dict, Any
import yaml


class mTSP:
    '''
    Solve the multiple traveling salesman problem (mTSP) with Q-Learning.
    Adapted from TSP.py, with support for multiple salesmen/drones added.
    '''
    def __init__(self,
                 num_cities: int = 15,
                 num_drones: int = 3,
                 map_size: Tuple[float, float] = (800.0, 600.0),
                 alpha: float = 2,
                 beta: float = 1,
                 learning_rate: float = 0.001,
                 eps: float = 0.1,
                 params_file: str = 'params2.yml'):
        '''
        Args:
            num_cities (int): number of real cities (excluding virtual start points)
            num_drones (int): number of drones
            map_size (int, int): map size (width, height)
            alpha (float): hyperparameter; larger values favor exploring the nearest points first
            beta (float): hyperparameter; larger values favor points likely to lead to the shortest total distance
            learning_rate (float): learning rate
            eps (float): exploration rate; larger values mean more exploration but harder convergence
            params_file (str): path to the parameter file
        '''
        self.num_cities = num_cities
        self.num_drones = num_drones
        self.map_size = map_size
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.learning_rate = learning_rate

        # load parameters
        with open(params_file, 'r', encoding='utf-8') as file:
            self.params = yaml.safe_load(file)

        # generate cities and virtual start points
        self.cities = self.generate_cities()
        self.to_process_idx = self.generate_start_points()

        # compute the distance matrix
        self.distances = self.get_dist_matrix()
        self.mean_distance = self.distances.mean()

        # Q-learning state
        self.qualities = np.zeros([self.total_cities, self.total_cities])
        self.normalizers = np.zeros(self.total_cities)
        self.best_path = None
        self.best_path_length = np.inf

        # compute the flight time and base-station time for each point
        self.rectangles = self.calculate_rectangles()

    @property
    def total_cities(self) -> int:
        """Total number of cities (including virtual start points)."""
        return self.num_cities + self.num_drones - 1

    def generate_cities(self) -> np.ndarray:
        '''Generate city coordinates.'''
        max_width, max_height = self.map_size
        cities = np.random.random([2, self.num_cities]) \
            * np.array([max_width, max_height]).reshape(2, -1)
        return cities

    def generate_start_points(self) -> List[int]:
        '''Generate the list of start-point indices.'''
        # add virtual start points
        virtual_starts = np.zeros([2, self.num_drones - 1])
        self.cities = np.hstack([self.cities, virtual_starts])
        return list(range(self.num_cities, self.total_cities))

    def get_dist_matrix(self) -> np.ndarray:
        '''Compute the distance matrix.'''
        dist_matrix = np.zeros([self.total_cities, self.total_cities])
        for i in range(self.total_cities):
            for j in range(self.total_cities):
                if i == j:
                    dist_matrix[i, j] = np.inf
                    continue
                xi, xj = self.cities[0, i], self.cities[0, j]
                yi, yj = self.cities[1, i], self.cities[1, j]
                dist_matrix[i, j] = np.sqrt((xi-xj)**2 + (yi-yj)**2)

        # set the distance between start points to infinity
        for i in self.to_process_idx:
            for j in self.to_process_idx:
                if i != j:
                    dist_matrix[i, j] = np.inf

        return dist_matrix

    def calculate_rectangles(self) -> List[Dict[str, Any]]:
        '''Compute the flight time and base-station time for each point.'''
        rectangles = []
        for i in range(self.num_cities):
            d = 1.0  # simplified here; in practice this should be derived from the region size
            rho_time_limit = (self.params['flight_time_factor'] - self.params['trans_time_factor']) / \
                (self.params['comp_time_factor'] - self.params['trans_time_factor'])
            rho_energy_limit = (self.params['battery_energy_capacity'] -
                                self.params['flight_energy_factor'] * d -
                                self.params['trans_energy_factor'] * d) / \
                (self.params['comp_energy_factor'] * d -
                 self.params['trans_energy_factor'] * d)
            rho = min(rho_time_limit, rho_energy_limit)

            flight_time = self.params['flight_time_factor'] * d
            bs_time = self.params['bs_time_factor'] * (1 - rho) * d

            rectangles.append({
                'flight_time': flight_time,
                'bs_time': bs_time,
                'center': (self.cities[0, i], self.cities[1, i])
            })
        return rectangles

    def rollout(self, start_city_id: int = None) -> Tuple[List[int], List[float], List[float]]:
        '''Perform one rollout (path exploration).'''
        cities_visited = []
        action_probs = []

        if start_city_id is None:
            start_city_id = np.random.choice(self.to_process_idx)
        current_city_id = start_city_id
        cities_visited.append(current_city_id)

        while len(cities_visited) < self.total_cities:
            current_city_id, action_prob = self.choose_next_city(cities_visited)
            cities_visited.append(current_city_id)
            action_probs.append(action_prob)

        # return to the start point
        cities_visited.append(cities_visited[0])
        action_probs.append(1.0)

        path_length = self.calc_path_length(cities_visited)
        if path_length < self.best_path_length:
            self.best_path = cities_visited
            self.best_path_length = path_length

        rewards = self.calc_path_rewards(cities_visited, path_length)
        return cities_visited, action_probs, rewards

    def choose_next_city(self, cities_visited: List[int]) -> Tuple[int, float]:
        '''Choose the next city.'''
        current_city_id = cities_visited[-1]
        probabilities = np.exp(self.qualities[current_city_id])

        # set the probability of already-visited cities to 0
        for city_visited in cities_visited:
            probabilities[city_visited] = 0

        # compute the softmax probabilities
        probabilities = probabilities/probabilities.sum()

        if np.random.random() < self.eps:
            next_city_id = np.random.choice(range(len(probabilities)), p=probabilities)
        else:
            next_city_id = probabilities.argmax()

        if probabilities.argmax() == next_city_id:
            action_prob = probabilities[next_city_id]*self.eps + (1-self.eps)
        else:
            action_prob = probabilities[next_city_id]*self.eps

        return next_city_id, action_prob

    def calc_path_length(self, path: List[int]) -> float:
        '''Compute the path length.'''
        # split the path into multiple sub-paths
        car_paths = []
        found_start_points = []

        # find all start points
        for i, city in enumerate(path):
            if city in self.to_process_idx:
                found_start_points.append(i)

        # split the path at the start points
        for i in range(len(found_start_points)-1):
            start_idx = found_start_points[i]
            end_idx = found_start_points[i+1]
            car_paths.append(path[start_idx:end_idx+1])

        # compute the time of each sub-path
        T_k_list = []
        for car_path in car_paths:
            flight_time = 0
            bs_time = 0
            car_time = 0

            # compute flight time and base-station time
            for point in car_path:
                if point not in self.to_process_idx:
                    flight_time += self.rectangles[point]['flight_time']
                    bs_time += self.rectangles[point]['bs_time']

            # compute vehicle (car) travel time
            for i in range(len(car_path)-1):
                if car_path[i] not in self.to_process_idx and car_path[i+1] not in self.to_process_idx:
                    car_time += self.distances[car_path[i], car_path[i+1]] * self.params['car_time_factor']

            # compute the total time
            system_time = max(flight_time + car_time, bs_time)
            T_k_list.append(system_time)

        return max(T_k_list)

    def calc_path_rewards(self, path: List[int], path_length: float) -> List[float]:
        '''Compute the path rewards.'''
        rewards = []
        for fr, to in zip(path[:-1], path[1:]):
            dist = self.distances[fr, to]
            reward = (self.mean_distance/path_length)**self.beta
            reward = reward*(self.mean_distance/dist)**self.alpha
            rewards.append(np.log(reward))
        return rewards

    def calc_updates_for_one_rollout(self, path: List[int], action_probs: List[float],
                                     rewards: List[float]) -> Tuple[List[float], List[float]]:
        '''Compute the update values for one rollout.'''
        qualities = []
        normalizers = []
        for fr, to, reward, action_prob in zip(path[:-1], path[1:], rewards, action_probs):
            log_action_probability = np.log(action_prob)
            qualities.append(- reward*log_action_probability)
            normalizers.append(- (reward + 1)*log_action_probability)
        return qualities, normalizers

    def update(self, path: List[int], new_qualities: List[float],
               new_normalizers: List[float]) -> None:
        '''Update the Q values and normalizers.'''
        lr = self.learning_rate
        for fr, to, new_quality, new_normalizer in zip(
                path[:-1], path[1:], new_qualities, new_normalizers):
            self.normalizers[fr] = (1-lr)*self.normalizers[fr] + lr*new_normalizer
            self.qualities[fr, to] = (1-lr)*self.qualities[fr, to] + lr*new_quality

    async def train_for_one_rollout(self, start_city_id: int) -> None:
        '''Train on one rollout.'''
        path, action_probs, rewards = self.rollout(start_city_id)
        new_qualities, new_normalizers = self.calc_updates_for_one_rollout(
            path, action_probs, rewards)
        self.update(path, new_qualities, new_normalizers)

    async def train_for_one_epoch(self) -> None:
        '''Train for one epoch.'''
        tasks = [self.train_for_one_rollout(start_city_id)
                 for start_city_id in self.to_process_idx]
        await asyncio.gather(*tasks)

    async def train(self, num_epochs: int = 1000, display: bool = True) -> None:
        '''Training loop.'''
        for epoch in range(num_epochs):
            await self.train_for_one_epoch()
            if display and epoch % 100 == 0:
                print(f"Epoch {epoch}: Best path length = {self.best_path_length:.2f}")


async def main():
    # create an mTSP instance
    mtsp = mTSP(num_cities=12, num_drones=3)

    # train the model
    await mtsp.train(200, display=True)

    # print the final path
    print(f"\nBest path: {mtsp.best_path}")
    print(f"Path length: {mtsp.best_path_length:.2f}")


if __name__ == '__main__':
    asyncio.run(main())
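params2.yml is not included in this commit, so its contents are unknown; the keys below are only those read by calculate_rectangles() and calc_path_length(), and every value is a placeholder rather than a value from the project. A sketch that writes a compatible stub:

import yaml

# Hypothetical stand-in for params2.yml: only the keys mTSP reads, placeholder values.
params_stub = {
    'flight_time_factor': 3.0,
    'comp_time_factor': 5.0,
    'trans_time_factor': 1.0,
    'bs_time_factor': 2.0,
    'flight_energy_factor': 0.5,
    'comp_energy_factor': 2.0,
    'trans_energy_factor': 0.3,
    'battery_energy_capacity': 10.0,
    'car_time_factor': 4.0,
}

with open('params2.yml', 'w', encoding='utf-8') as f:
    yaml.safe_dump(params_stub, f)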
277  Q_learning/test.ipynb  Normal file
File diff suppressed because one or more lines are too long
BIN  Q_learning/tsp.png  Normal file
Binary file not shown. (new image, 45 KiB)