HPCC2025/Q_learning/TSP_origin.py

import pylab as plt
import numpy as np
import asyncio
class TSP(object):
    '''
    Solve the TSP (Traveling Salesman Problem) with Q-Learning.
    Author: Surfer Zen @ https://www.zhihu.com/people/surfer-zen
    '''
    def __init__(self,
                 num_cities=15,
                 map_size=(800.0, 600.0),
                 alpha=2,
                 beta=1,
                 learning_rate=0.001,
                 eps=0.1,
                 ):
        '''
        Args:
            num_cities (int): number of cities
            map_size (float, float): map size (width, height)
            alpha (float): hyperparameter; the larger it is, the more the policy favors the nearest city
            beta (float): hyperparameter; the larger it is, the more the policy favors cities likely to lead to the shortest total tour
            learning_rate (float): learning rate
            eps (float): exploration rate; larger values explore more but converge more slowly
        '''
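        # Note: alpha and beta are the exponents of the log-reward computed in
        # calc_path_rewards() below; alpha weights the per-step distance term,
        # beta weights the total tour-length term.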
self.num_cities = num_cities
self.map_size = map_size
self.alpha = alpha
self.beta = beta
self.eps = eps
self.learning_rate = learning_rate
self.cities = self.generate_cities()
self.distances = self.get_dist_matrix()
self.mean_distance = self.distances.mean()
self.qualities = np.zeros([num_cities, num_cities])
self.normalizers = np.zeros(num_cities)
self.best_path = None
self.best_path_length = np.inf
def generate_cities(self):
        '''
        Randomly generate city coordinates.
        Returns:
            cities: [[x1, x2, x3, ...], [y1, y2, y3, ...]] city coordinates
        '''
max_width, max_height = self.map_size
cities = np.random.random([2, self.num_cities]) \
* np.array([max_width, max_height]).reshape(2, -1)
return cities
def get_dist_matrix(self):
        '''
        Compute the pairwise distance matrix from the city coordinates.
        '''
dist_matrix = np.zeros([self.num_cities, self.num_cities])
for i in range(self.num_cities):
for j in range(self.num_cities):
if i == j:
continue
xi, xj = self.cities[0, i], self.cities[0, j]
yi, yj = self.cities[1, i], self.cities[1, j]
dist_matrix[i, j] = np.sqrt((xi-xj)**2 + (yi-yj)**2)
return dist_matrix
def rollout(self, start_city_id=None):
        '''
        Starting from start_city, walk between cities according to the policy
        until every city has been visited.
        '''
cities_visited = []
action_probs = []
if start_city_id is None:
start_city_id = np.random.randint(self.num_cities)
current_city_id = start_city_id
cities_visited.append(current_city_id)
while len(cities_visited) < self.num_cities:
current_city_id, action_prob = self.choose_next_city(
cities_visited)
cities_visited.append(current_city_id)
action_probs.append(action_prob)
cities_visited.append(cities_visited[0])
action_probs.append(1.0)
path_length = self.calc_path_length(cities_visited)
if path_length < self.best_path_length:
self.best_path = cities_visited
self.best_path_length = path_length
rewards = self.calc_path_rewards(cities_visited, path_length)
return cities_visited, action_probs, rewards
def choose_next_city(self, cities_visited):
        '''
        Choose the next city according to the policy.
        '''
current_city_id = cities_visited[-1]
        # Exponentiate the qualities (used below to compute the softmax probabilities)
probabilities = np.exp(self.qualities[current_city_id])
        # Zero out the probabilities of the cities already visited
for city_visited in cities_visited:
probabilities[city_visited] = 0
        # Normalize to get the (masked) softmax probabilities
probabilities = probabilities/probabilities.sum()
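        # i.e. p(j | i) = exp(Q[i, j]) / sum over unvisited k of exp(Q[i, k])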
if np.random.random() < self.eps:
            # With probability eps, sample randomly from the softmax distribution
next_city_id = np.random.choice(
range(len(probabilities)), p=probabilities)
else:
            # With probability (1 - eps), pick the current greedy-optimal city
next_city_id = probabilities.argmax()
        # Compute the probability of the chosen decision/action
if probabilities.argmax() == next_city_id:
action_prob = probabilities[next_city_id]*self.eps + (1-self.eps)
else:
action_prob = probabilities[next_city_id]*self.eps
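        # Under the eps-greedy-softmax mixture this is the total selection probability:
        # P(a) = eps * softmax(a) + (1 - eps) * 1[a == argmax]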
return next_city_id, action_prob
def calc_path_rewards(self, path, path_length):
        '''
        Compute the rewards for a given path.
        Args:
            path (list[int]): the path; each element is a city id
            path_length (float): total length of the path
        Returns:
            rewards: reward for each step; the larger the total distance and the
                current step's distance, the smaller the reward
        '''
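        # Per-step log-reward:
        #   r = beta * log(mean_distance / path_length) + alpha * log(mean_distance / step_dist)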
rewards = []
for fr, to in zip(path[:-1], path[1:]):
dist = self.distances[fr, to]
reward = (self.mean_distance/path_length)**self.beta
reward = reward*(self.mean_distance/dist)**self.alpha
rewards.append(np.log(reward))
return rewards
def calc_path_length(self, path):
        '''
        Compute the length of a path.
        '''
path_length = 0
for fr, to in zip(path[:-1], path[1:]):
path_length += self.distances[fr, to]
return path_length
def calc_updates_for_one_rollout(self, path, action_probs, rewards):
        '''
        For a single rollout, compute the corresponding qualities and normalizers.
        '''
qualities = []
normalizers = []
for fr, to, reward, action_prob in zip(path[:-1], path[1:], rewards, action_probs):
log_action_probability = np.log(action_prob)
qualities.append(- reward*log_action_probability)
normalizers.append(- (reward + 1)*log_action_probability)
return qualities, normalizers
def update(self, path, new_qualities, new_normalizers):
        '''
        Update qualities and normalizers using a running (asymptotic) average.
        '''
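        # Exponential moving average: x <- (1 - lr) * x + lr * x_new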
lr = self.learning_rate
for fr, to, new_quality, new_normalizer in zip(
path[:-1], path[1:], new_qualities, new_normalizers):
self.normalizers[fr] = (
1-lr)*self.normalizers[fr] + lr*new_normalizer
self.qualities[fr, to] = (
1-lr)*self.qualities[fr, to] + lr*new_quality
async def train_for_one_rollout(self, start_city_id):
        '''
        Training procedure for a single rollout.
        '''
path, action_probs, rewards = self.rollout(start_city_id=start_city_id)
new_qualities, new_normalizers = self.calc_updates_for_one_rollout(
path, action_probs, rewards)
self.update(path, new_qualities, new_normalizers)
async def train_for_one_epoch(self):
        '''
        Training procedure for one epoch.
        One epoch corresponds to one rollout starting from each city.
        '''
tasks = [self.train_for_one_rollout(
start_city_id) for start_city_id in range(self.num_cities)]
await asyncio.gather(*tasks)
async def train(self, num_epochs=1000, display=True):
        '''
        Overall training loop.
        '''
for epoch in range(num_epochs):
await self.train_for_one_epoch()
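

# A minimal usage sketch, not called anywhere in this file: plot the best tour
# found so far with the already-imported pylab. plot_best_path is a hypothetical
# helper shown for illustration only.
def plot_best_path(tsp):
    '''Scatter the cities and draw tsp.best_path on top of them.'''
    xs, ys = tsp.cities[0], tsp.cities[1]
    plt.figure()
    plt.scatter(xs, ys)
    if tsp.best_path is not None:
        plt.plot(xs[tsp.best_path], ys[tsp.best_path])
    plt.title(f'best length = {tsp.best_path_length:.2f}')
    plt.show()
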
async def main():
    # Create a TSP instance
    tsp = TSP()
    # Train the model
    await tsp.train(200, display=False)
    # Print the final path
    print(f"Best path: {tsp.best_path}")
    print(f"Path length: {tsp.best_path_length:.2f}")
if __name__ == "__main__":
np.random.seed(42)
    # Run the async main function with asyncio.run()
asyncio.run(main())