278 lines
69 KiB
Plaintext
278 lines
69 KiB
Plaintext
![]() |
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"%matplotlib inline\n",
|
|||
|
"import pylab as plt\n",
|
|||
|
"from IPython.display import clear_output\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import asyncio\n",
|
|||
|
"\n",
|
|||
|
class TSP(object):
    '''
    Solve the Travelling Salesman Problem (TSP) with a Q-learning /
    policy-gradient style scheme: per-edge "qualities" are learned from
    rollouts and turned into a softmax policy over the next city.

    Original author: Surfer Zen @ https://www.zhihu.com/people/surfer-zen
    '''
    def __init__(self,
                 num_cities=15,
                 map_size=(800.0, 600.0),
                 alpha=2,
                 beta=1,
                 learning_rate=0.001,
                 eps=0.1,
                 ):
        '''
        Args:
            num_cities (int): number of cities
            map_size (float, float): map size as (width, height)
            alpha (float): hyper-parameter; larger values favour exploring
                the nearest city first
            beta (float): hyper-parameter; larger values favour moves likely
                to lead to a shorter total tour
            learning_rate (float): step size of the running-average updates
            eps (float): exploration rate; larger values explore more but
                converge more slowly
        '''
        self.num_cities = num_cities
        self.map_size = map_size
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.learning_rate = learning_rate
        self.cities = self.generate_cities()
        self.distances = self.get_dist_matrix()
        self.mean_distance = self.distances.mean()
        # qualities[i, j]: learned desirability of travelling from city i to j
        self.qualities = np.zeros([num_cities, num_cities])
        # per-city normalizer estimates; only used as a diagnostic in draw()
        self.normalizers = np.zeros(num_cities)
        self.best_path = None
        self.best_path_length = np.inf

    def generate_cities(self):
        '''
        Randomly generate city coordinates.

        Returns:
            cities: ndarray of shape (2, num_cities),
                    row 0 holds x coordinates, row 1 holds y coordinates
        '''
        max_width, max_height = self.map_size
        cities = np.random.random([2, self.num_cities]) \
            * np.array([max_width, max_height]).reshape(2, -1)
        return cities

    def get_dist_matrix(self):
        '''
        Compute the pairwise Euclidean distance matrix from the city
        coordinates. The diagonal is left at zero.
        '''
        dist_matrix = np.zeros([self.num_cities, self.num_cities])
        for i in range(self.num_cities):
            for j in range(self.num_cities):
                if i == j:
                    continue
                xi, xj = self.cities[0, i], self.cities[0, j]
                yi, yj = self.cities[1, i], self.cities[1, j]
                dist_matrix[i, j] = np.sqrt((xi-xj)**2 + (yi-yj)**2)
        return dist_matrix

    def rollout(self, start_city_id=None):
        '''
        Starting from start_city, walk between cities following the current
        policy until every city has been visited once, then close the tour
        by returning to the start. Updates the best-path bookkeeping.

        Args:
            start_city_id (int | None): starting city; random if None

        Returns:
            (cities_visited, action_probs, rewards):
                cities_visited (list[int]): closed tour, first city repeated
                    at the end
                action_probs (list[float]): probability of each chosen step
                    (the forced return to the start has probability 1.0)
                rewards (list[float]): per-step rewards for the tour
        '''
        cities_visited = []
        action_probs = []
        if start_city_id is None:
            start_city_id = np.random.randint(self.num_cities)
        current_city_id = start_city_id
        cities_visited.append(current_city_id)
        while len(cities_visited) < self.num_cities:
            current_city_id, action_prob = self.choose_next_city(cities_visited)
            cities_visited.append(current_city_id)
            action_probs.append(action_prob)
        # close the loop back to the start; this move is deterministic
        cities_visited.append(cities_visited[0])
        action_probs.append(1.0)

        path_length = self.calc_path_length(cities_visited)
        if path_length < self.best_path_length:
            self.best_path = cities_visited
            self.best_path_length = path_length
        rewards = self.calc_path_rewards(cities_visited, path_length)
        return cities_visited, action_probs, rewards

    def choose_next_city(self, cities_visited):
        '''
        Choose the next city according to the current policy
        (eps-greedy over a softmax of the learned qualities).

        Returns:
            (next_city_id, action_prob): chosen city and the overall
            probability of that decision under the eps-greedy policy
        '''
        current_city_id = cities_visited[-1]

        # Exponentiate the qualities for the softmax. Subtracting the row
        # maximum first guards against overflow; softmax is shift-invariant,
        # so the normalization below yields the identical distribution.
        row = self.qualities[current_city_id]
        probabilities = np.exp(row - row.max())

        # zero out the probability of cities already visited
        for city_visited in cities_visited:
            probabilities[city_visited] = 0

        # normalize into a softmax distribution over the unvisited cities
        probabilities = probabilities/probabilities.sum()

        if np.random.random() < self.eps:
            # with probability eps: sample from the softmax density
            next_city_id = np.random.choice(range(len(probabilities)), p=probabilities)
        else:
            # with probability (1 - eps): pick the current greedy choice
            next_city_id = probabilities.argmax()

        # overall probability of this decision under the eps-greedy mixture
        if probabilities.argmax() == next_city_id:
            action_prob = probabilities[next_city_id]*self.eps + (1-self.eps)
        else:
            action_prob = probabilities[next_city_id]*self.eps

        return next_city_id, action_prob

    def calc_path_rewards(self, path, path_length):
        '''
        Compute the per-step rewards for a given path.

        Args:
            path (list[int]): tour, each element is a city id
            path_length (float): total length of the tour

        Returns:
            rewards (list[float]): one (log-scaled) reward per step; both a
            longer total tour and a longer individual step reduce the reward
        '''
        rewards = []
        for fr, to in zip(path[:-1], path[1:]):
            dist = self.distances[fr, to]
            # beta weights the global (total tour) term,
            # alpha weights the local (single step) term
            reward = (self.mean_distance/path_length)**self.beta
            reward = reward*(self.mean_distance/dist)**self.alpha
            rewards.append(np.log(reward))
        return rewards

    def calc_path_length(self, path):
        '''
        Compute the total length of a path (sum of consecutive-city distances).
        '''
        path_length = 0
        for fr, to in zip(path[:-1], path[1:]):
            path_length += self.distances[fr, to]
        return path_length

    def calc_updates_for_one_rollout(self, path, action_probs, rewards):
        '''
        For the result of a single rollout, compute the corresponding
        quality and normalizer update targets.
        '''
        qualities = []
        normalizers = []
        for fr, to, reward, action_prob in zip(path[:-1], path[1:], rewards, action_probs):
            log_action_probability = np.log(action_prob)
            qualities.append(- reward*log_action_probability)
            normalizers.append(- (reward + 1)*log_action_probability)
        return qualities, normalizers

    def update(self, path, new_qualities, new_normalizers):
        '''
        Update the qualities and normalizers with an exponential moving
        average (asymptotic-mean style).
        '''
        lr = self.learning_rate
        for fr, to, new_quality, new_normalizer in zip(
                path[:-1], path[1:], new_qualities, new_normalizers):
            self.normalizers[fr] = (1-lr)*self.normalizers[fr] + lr*new_normalizer
            self.qualities[fr, to] = (1-lr)*self.qualities[fr, to] + lr*new_quality

    async def train_for_one_rollout(self, start_city_id):
        '''
        Training procedure for a single rollout: sample a tour, compute
        the update targets, and apply them.
        '''
        path, action_probs, rewards = self.rollout(start_city_id=start_city_id)
        new_qualities, new_normalizers = self.calc_updates_for_one_rollout(path, action_probs, rewards)
        self.update(path, new_qualities, new_normalizers)

    async def train_for_one_epoch(self):
        '''
        Training procedure for one epoch: one rollout starting from each
        city. NOTE(review): the coroutines contain no await points, so
        asyncio.gather runs them effectively sequentially.
        '''
        tasks = [self.train_for_one_rollout(start_city_id) for start_city_id in range(self.num_cities)]
        await asyncio.gather(*tasks)

    async def train(self, num_epochs=1000, display=True):
        '''
        Top-level training loop.

        Args:
            num_epochs (int): number of epochs to train for
            display (bool): if True, redraw the current best tour each epoch
        '''
        for epoch in range(num_epochs):
            await self.train_for_one_epoch()
            if display:
                self.draw(epoch)

    def draw(self, epoch):
        '''
        Plot the cities and the current best tour, save to 'tsp.png',
        and report the normalizer consistency error in the title.
        '''
        _ = plt.scatter(*self.cities)
        for fr, to in zip(self.best_path[:-1], self.best_path[1:]):
            x1, y1 = self.cities[:, fr]
            x2, y2 = self.cities[:, to]
            dx, dy = x2-x1, y2-y1
            plt.arrow(x1, y1, dx, dy, width=0.01*min(self.map_size),
                      edgecolor='orange', facecolor='white', animated=True,
                      length_includes_head=True)
        # diagnostic: how far exp(normalizers) drifts from the true row sums
        nrs = np.exp(self.qualities)
        for i in range(self.num_cities):
            nrs[i, i] = 0
        gap = np.abs(np.exp(self.normalizers) - nrs.sum(-1)).mean()
        plt.title(f'epoch {epoch}: path length = {self.best_path_length:.2f}, normalizer error = {gap:.3f}')
        plt.savefig('tsp.png')
        plt.show()
        clear_output(wait=True)
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGzCAYAAADe/0a6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACsWUlEQVR4nOzdd3zT5fbA8U/SvQd00hbKpgKizCJLQIaAIlAvThwXfyIOwIlXBdx7Xvd148QJKCACikBliuxN6aALSidNR/L9/fFktgVaaEnTnvfrlVeb5NvkSdskJ+c5z3l0mqZpCCGEEEK4GL2zByCEEEIIcTYkiBFCCCGES5IgRgghhBAuSYIYIYQQQrgkCWKEEEII4ZIkiBFCCCGES5IgRgghhBAuSYIYIYQQQrgkCWKEEEII4ZIkiDlLc+fORafTcezYMWcPpUnR6XTceeedZ/WzKSkp6HQ6Pv744/odVANo06YNY8eOdfYwhKjm999/R6fT8fvvv1svu+mmm2jTpo3TxiTEqUgQ4yR79+5l5syZ9O/fH29vb3Q6HSkpKTUeW1xczIwZM4iJicHLy4suXbrw9ttv13js8uXLGTBgAL6+voSEhDBp0qRT3q6zrFu3jrlz55Kfn+/soTSoXbt2MXfu3Eb3+6/q6aefpl+/foSFheHt7U2HDh2YMWMGubm51Y41mUw8//zzxMfH4+3tTffu3fnyyy9rvN3du3czatQo/P39CQ0N5YYbbqjxNu19/vnn6HQ6/P39z+qxTJ06FZ1OV2OA2KZNG3Q6XbXT7bffflb3JYQzvf322yQlJREXF4dOp+Omm26q8bjMzEweeughLr30UgICAqoFqGfyww8/MHLkSKKjo/Hy8iImJoZJkyaxY8eOasee7XPsdM/bM3Gv80+IepGcnMzrr79OQkICXbp0YevWrTUeZzQaGTlyJJs2bWL69Ol06NCBZcuWcccdd3DixAkefvhh67GLFy/myiuv5OKLL+bZZ5+lsLCQ1157jQEDBvD3338TFhZ2nh7d6a1bt4558+Zx0003ERwc7OzhNJhdu3Yxb948hgwZ0qg/xW7evJkePXowefJkAgIC2L17N++//z4///wzW7duxc/Pz3rsf/7zH5599lmmTp1K7969+emnn7j22mvR6XRMnjzZelx6ejqDBg0iKCiIp59+muLiYl588UW2b9/Ohg0b8PT0rDaO4uJiHnjgAYf7q4tNmzbx8ccf4+3tfcpjevTowb333utwWceOHc/q/pqT999/H5PJ5OxhCDvPPfccRUVF9OnTh8zMzFMet3fvXp577jk6dOhAt27dSE5OrtP9bN++nZCQEO655x5atmxJVlYWH374IX369CE5OZkLL7zQ4fi6Psdq87w9LU2clTlz5miAlpube1Y/f/z4ca2wsFDTNE174YUXNEA7fPhwteO++eYbDdA++OADh8snTpyoeXt7a9nZ2dbLEhIStPbt22tlZWXWy7Zu3arp9Xpt1qxZZzXOhnC6xwto06dPP6vbPXz4sAZoH3300bkNsJ4sWLBAA7RVq1ZVu65169bamDFjzv+gaunbb7/VAO3LL7+0Xpaenq55eHg4/H1MJpM2cOBALSYmRqusrLRePm3aNM3Hx0c7cuSI9bLly5drgPbuu+/WeJ8PPvig1qlTJ+26667T/Pz86jRek8mkJSYmarfccsspf7eN/Xd+KsXFxef1/latWnXK/9vzpaKiwuF1rD6VlJQ02H2ez79VSkqKZjKZNE3TND8/P23KlCk1HldYWKgdP35c07TTvybVRVZWlubu7q793//9n8PldX2O1eZ5eyaNejopIyODW265hYiICLy8vLjgggv48MMPHY6xzN9+/fXXPPzww0RGRuLn58cVV1xBWlpatdtcsGABPXv2xMfHh5YtW3L99deTkZFR7bg9e/Zw9dVXExYWho+PD506deI///lPtePy8/OtGYWgoCBuvvlmTp48ecbHFh
oaSkBAwBmP+/PPPwEcPuVazhsMBn766ScA8vLy2LVrF1dddZXDp9wLL7yQLl268NVXXzn8fGZmJnv27KGiouK092+pM3nxxRd55ZVXaN26NT4+PgwePLhaOnHbtm3cdNNNtG3bFm9vbyIjI7nllls4fvy49Zi5c+dy//33AxAfH29NN1adcvnxxx/p2rWr9e++dOnSM/6uTmXPnj1MmjSJ0NBQvL296dWrFwsXLnQ45uOPP0an07F27VpmzZpFWFgYfn5+XHXVVdWmQEwmE3PnziU6OhpfX18uvfRSdu3aRZs2bawp3Y8//pikpCQALr30UuvjrJrGXbNmDX369MHb25u2bdvy6aefnvXjrE+WzJH9lN9PP/1ERUUFd9xxh/UynU7HtGnTSE9Pd/iE99133zF27Fji4uKslw0fPpyOHTvyzTffVLu//fv388orr/Dyyy/j7l73BPFnn33Gjh07eOqpp854bHl5OSUlJXW+j5pYarhq8//6999/M3r0aAIDA/H392fYsGH89ddfDsdY/g//+OMP7rjjDsLDw4mJiQFgyJAhdO3alW3btjF48GB8fX1p37493377LQB//PEHffv2tb5e/fbbbw63feTIEe644w46deqEj48PLVq0ICkpqVbTnVVrYoYMGVLjtEHVmrT8/HxmzJhBbGwsXl5etG/fnueee84hq2P/GvPqq6/Srl07vLy82LVr12nHNH/+fOtreWhoKJMnT672mm/5nW3evJlBgwbh6+vLww8/fMb7XLlyJQMHDsTPz4/g4GCuvPJKdu/e7XDblrrIXbt2ce211xISEsKAAQPO+LusL61bt0an053xuICAAEJDQ+v1vsPDw/H19T1lSUBtn2N1ed6eSqOdTsrOzqZfv37WF4mwsDCWLFnCrbfeSmFhITNmzHA4/qmnnkKn0/Hggw+Sk5PDq6++yvDhw9m6dSs+Pj6AeoG4+eab6d27N8888wzZ2dm89tprrF27lr///ts6tbFt2zYGDhyIh4cHt912G23atOHgwYMsWrSo2i/76quvJj4+nmeeeYYtW7bwv//9j/DwcJ577rl6+T2UlZXh5uZWLf3u6+sLqKmAqVOnUlZWBmB9rFWP3blzJ1lZWURGRgIwe/ZsPvnkEw4fPlyrqY5PP/2UoqIipk+fjsFg4LXXXmPo0KFs376diIgIQNXjHDp0iJtvvpnIyEh27tzJe++9x86dO/nrr7/Q6XRMmDCBffv28eWXX/LKK6/QsmVLAIeprjVr1vD9999zxx13EBAQwOuvv87EiRNJTU2lRYsWdfr97dy5k0suuYRWrVrx0EMP4efnxzfffMP48eP57rvvuOqqqxyOv+uuuwgJCWHOnDmkpKTw6quvcuedd/L1119bj5k9ezbPP/8848aNY+TIkfzzzz+MHDkSg8FgPWbQoEHcfffdvP766zz88MN06dIFwPoV4MCBA0yaNIlbb72VKVOm8OGHH3LTTTfRs2dPLrjggtM+rhMnTmA0Gs/4+H19fa3/K6ejaRrHjx+nsrKS/fv389BDD+Hm5saQIUOsx/z999/4+fk5PAaAPn36WK8fMGAAGRkZ5OTk0KtXr2r306dPH3755Zdql8+YMYNLL72Uyy+/vMYg53SKiop48MEHrR9iTmflypX4+vpiNBpp3bo1M2fO5J577qnT/VVVm//XnTt3MnDgQAIDA3nggQfw8PDg3XffZciQIdbgw94dd9xBWFgYjz32mMObwYkTJxg7diyTJ08mKSmJt99+m8mTJ/P5558zY8YMbr/9dq699lpeeOEFJk2aRFpamvXD0saNG1m3bh2TJ08mJiaGlJQU3n77bYYMGcKuXbtq9X9i8Z///Id///vfDpfNnz+fZcuWER4eDsDJkycZPHgwGRkZ/N///R9xcXGsW7eO2bNnk5mZyauvvurw8x999BEGg4HbbrsNLy+v077xPvXUUzz66KNcffXV/Pvf/yY3N5c33niDQYMGObyWAx
w/fpzRo0czefJkrr/+euvr1anu87fffmP06NG0bduWuXPnUlpayhtvvMEll1zCli1bqr1eJiUl0aFDB55++mk0TTv
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
# Build a solver with the default settings and train for 200 epochs.
# NOTE: top-level `await` relies on the notebook's already-running event loop;
# in a plain script this would need asyncio.run(tsp.train(200)).
tsp = TSP()
await tsp.train(200)
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "PPO2",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.10.16"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 2
|
|||
|
}
|