Update q_table
This commit is contained in:
parent 6f8fcd15b7
commit 1485fb2bd6

Q_learning/TSP.py  (new file, 216 lines)
@@ -0,0 +1,216 @@
%matplotlib inline
import pylab as plt
from IPython.display import clear_output
import numpy as np
import asyncio


class TSP(object):
    '''
    Solve the TSP problem with Q-Learning
    Author: Surfer Zen @ https://www.zhihu.com/people/surfer-zen
    '''
    def __init__(self,
                 num_cities=15,
                 map_size=(800.0, 600.0),
                 alpha=2,
                 beta=1,
                 learning_rate=0.001,
                 eps=0.1,
                 ):
        '''
        Args:
            num_cities (int): number of cities
            map_size (int, int): map size (width, height)
            alpha (float): hyperparameter; the larger it is, the more the agent favors the nearest city
            beta (float): hyperparameter; the larger it is, the more the agent favors cities likely to lead to the shortest total distance
            learning_rate (float): learning rate
            eps (float): exploration rate; larger values explore more but make convergence harder
        '''
        self.num_cities = num_cities
        self.map_size = map_size
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.learning_rate = learning_rate
        self.cities = self.generate_cities()
        self.distances = self.get_dist_matrix()
        self.mean_distance = self.distances.mean()
        self.qualities = np.zeros([num_cities, num_cities])
        self.normalizers = np.zeros(num_cities)
        self.best_path = None
        self.best_path_length = np.inf


    def generate_cities(self):
        '''
        Randomly generate city coordinates
        Returns:
            cities: [[x1, x2, x3...], [y1, y2, y3...]] city coordinates
        '''
        max_width, max_height = self.map_size
        cities = np.random.random([2, self.num_cities]) \
            * np.array([max_width, max_height]).reshape(2, -1)
        return cities

    def get_dist_matrix(self):
        '''
        Compute the distance matrix from the city coordinates
        '''
        dist_matrix = np.zeros([self.num_cities, self.num_cities])
        for i in range(self.num_cities):
            for j in range(self.num_cities):
                if i == j:
                    continue
                xi, xj = self.cities[0, i], self.cities[0, j]
                yi, yj = self.cities[1, i], self.cities[1, j]
                dist_matrix[i, j] = np.sqrt((xi-xj)**2 + (yi-yj)**2)
        return dist_matrix

    def rollout(self, start_city_id=None):
        '''
        Starting from start_city, walk between cities according to the policy
        until every city has been visited once
        '''
        cities_visited = []
        action_probs = []
        if start_city_id is None:
            start_city_id = np.random.randint(self.num_cities)
        current_city_id = start_city_id
        cities_visited.append(current_city_id)
        while len(cities_visited) < self.num_cities:
            current_city_id, action_prob = self.choose_next_city(cities_visited)
            cities_visited.append(current_city_id)
            action_probs.append(action_prob)
        cities_visited.append(cities_visited[0])
        action_probs.append(1.0)

        path_length = self.calc_path_length(cities_visited)
        if path_length < self.best_path_length:
            self.best_path = cities_visited
            self.best_path_length = path_length
        rewards = self.calc_path_rewards(cities_visited, path_length)
        return cities_visited, action_probs, rewards

    def choose_next_city(self, cities_visited):
        '''
        Choose the next city according to the policy
        '''
        current_city_id = cities_visited[-1]

        # exponentiate the qualities, used for the softmax probabilities
        probabilities = np.exp(self.qualities[current_city_id])

        # set the probabilities of already-visited cities to zero
        for city_visited in cities_visited:
            probabilities[city_visited] = 0

        # compute the softmax probabilities
        probabilities = probabilities/probabilities.sum()

        if np.random.random() < self.eps:
            # with probability eps, sample from the softmax distribution
            next_city_id = np.random.choice(range(len(probabilities)), p=probabilities)
        else:
            # with probability (1 - eps), follow the current greedy policy
            next_city_id = probabilities.argmax()

        # compute the probability of the chosen decision/action
        if probabilities.argmax() == next_city_id:
            action_prob = probabilities[next_city_id]*self.eps + (1-self.eps)
        else:
            action_prob = probabilities[next_city_id]*self.eps

        return next_city_id, action_prob

    def calc_path_rewards(self, path, path_length):
        '''
        Compute the rewards for a given path
        Args:
            path (list[int]): path; each element is a city id
            path_length (float): path length
        Returns:
            rewards: reward for each step; the larger the total distance and the current step's distance, the smaller the reward
        '''
        rewards = []
        for fr, to in zip(path[:-1], path[1:]):
            dist = self.distances[fr, to]
            reward = (self.mean_distance/path_length)**self.beta
            reward = reward*(self.mean_distance/dist)**self.alpha
            rewards.append(np.log(reward))
        return rewards

    def calc_path_length(self, path):
        '''
        Compute the path length
        '''
        path_length = 0
        for fr, to in zip(path[:-1], path[1:]):
            path_length += self.distances[fr, to]
        return path_length

    def calc_updates_for_one_rollout(self, path, action_probs, rewards):
        '''
        For the result of one rollout, compute the corresponding qualities and normalizers
        '''
        qualities = []
        normalizers = []
        for fr, to, reward, action_prob in zip(path[:-1], path[1:], rewards, action_probs):
            log_action_probability = np.log(action_prob)
            qualities.append(- reward*log_action_probability)
            normalizers.append(- (reward + 1)*log_action_probability)
        return qualities, normalizers

    def update(self, path, new_qualities, new_normalizers):
        '''
        Update the qualities and normalizers as running (asymptotic) averages
        '''
        lr = self.learning_rate
        for fr, to, new_quality, new_normalizer in zip(
                path[:-1], path[1:], new_qualities, new_normalizers):
            self.normalizers[fr] = (1-lr)*self.normalizers[fr] + lr*new_normalizer
            self.qualities[fr, to] = (1-lr)*self.qualities[fr, to] + lr*new_quality

    async def train_for_one_rollout(self, start_city_id):
        '''
        Training procedure for one rollout
        '''
        path, action_probs, rewards = self.rollout(start_city_id=start_city_id)
        new_qualities, new_normalizers = self.calc_updates_for_one_rollout(path, action_probs, rewards)
        self.update(path, new_qualities, new_normalizers)

    async def train_for_one_epoch(self):
        '''
        Training procedure for one epoch;
        one epoch corresponds to one rollout starting from each city
        '''
        tasks = [self.train_for_one_rollout(start_city_id) for start_city_id in range(self.num_cities)]
        await asyncio.gather(*tasks)

    async def train(self, num_epochs=1000, display=True):
        '''
        Overall training procedure
        '''
        for epoch in range(num_epochs):
            await self.train_for_one_epoch()
            if display:
                self.draw(epoch)

    def draw(self, epoch):
        '''
        Plotting
        '''
        _ = plt.scatter(*self.cities)
        for fr, to in zip(self.best_path[:-1], self.best_path[1:]):
            x1, y1 = self.cities[:, fr]
            x2, y2 = self.cities[:, to]
            dx, dy = x2-x1, y2-y1
            plt.arrow(x1, y1, dx, dy, width=0.01*min(self.map_size),
                      edgecolor='orange', facecolor='white', animated=True,
                      length_includes_head=True)
        nrs = np.exp(self.qualities)
        for i in range(self.num_cities):
            nrs[i, i] = 0
        gap = np.abs(np.exp(self.normalizers) - nrs.sum(-1)).mean()
        plt.title(f'epoch {epoch}: path length = {self.best_path_length:.2f}, normalizer error = {gap:.3f}')
        plt.savefig('tsp.png')
        plt.show()
        clear_output(wait=True)
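Not part of the commit: a minimal usage sketch for the TSP class above. The parameter values are illustrative, and since train is a coroutine it has to be awaited (top-level await works in an IPython/Jupyter cell, which the %matplotlib inline magic already assumes):

# illustrative driver code, assuming the file above has been run in a notebook cell
tsp = TSP(num_cities=10, eps=0.1)
await tsp.train(num_epochs=200, display=True)
print(tsp.best_path, tsp.best_path_length)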
@@ -3,232 +3,316 @@ import numpy as np
import json
import math
import yaml
# parameter settings
STEP = 0.01
VALUES = [round(i*STEP, 2) for i in range(101)]  # 0.00~1.00
ACTION_DELTA = [STEP, -STEP]  # increase or decrease by 0.01
ACTIONS = []  # each action is (var_index, delta)
for i in range(3):
    for delta in ACTION_DELTA:
        ACTIONS.append((i, delta))

ALPHA = 0.1  # learning rate
GAMMA = 0.9  # discount factor
EPSILON = 0.2  # exploration rate
NUM_EPISODES = 100
from typing import Tuple, List, Dict, Any, Optional
from dataclasses import dataclass

def f(state):
    """
    Compute the objective value T for the given cut ratios (placeholder function)
    :param row_cuts: row cut ratios
    :param col_cuts: column cut ratios
    :return: objective value T
    """
    with open('params2.yml', 'r', encoding='utf-8') as file:
        params = yaml.safe_load(file)

    H = params['H']
    W = params['W']
    num_cars = params['num_cars']
@dataclass
class Config:
    """Configuration class that gathers all hyperparameters in one place"""
    # learning parameters
    ALPHA: float = 0.1  # learning rate
    GAMMA: float = 0.9  # discount factor
    EPSILON: float = 0.2  # exploration rate
    NUM_EPISODES: int = 100  # number of training episodes

    flight_time_factor = params['flight_time_factor']
    comp_time_factor = params['comp_time_factor']
    trans_time_factor = params['trans_time_factor']
    car_time_factor = params['car_time_factor']
    bs_time_factor = params['bs_time_factor']
    # state space parameters
    STEP: float = 0.01  # step size of a state change
    VALUES: List[float] = None  # list of discrete values 0.00~1.00

    flight_energy_factor = params['flight_energy_factor']
    comp_energy_factor = params['comp_energy_factor']
    trans_energy_factor = params['trans_energy_factor']
    battery_energy_capacity = params['battery_energy_capacity']
    # action space parameters
    ACTION_DELTA: List[float] = None  # action deltas

    col_cuts = list(state)
    col_cuts.insert(0, 0)
    col_cuts.append(1)
    row_cuts = [0, 0.5, 1]
    rectangles = []
    for i in range(len(row_cuts) - 1):
        for j in range(len(col_cuts) - 1):
            d = (col_cuts[j+1] - col_cuts[j]) * W * \
                (row_cuts[i+1] - row_cuts[i]) * H
            rho_time_limit = (flight_time_factor - trans_time_factor) / \
                (comp_time_factor - trans_time_factor)
            rho_energy_limit = (battery_energy_capacity - flight_energy_factor * d - trans_energy_factor * d) / (comp_energy_factor * d - trans_energy_factor * d)
            if rho_energy_limit < 0:
                return 100000
            rho = min(rho_time_limit, rho_energy_limit)
    # environment parameters
    MIN_IMPROVEMENT: float = 0.001  # minimum improvement threshold
    MAX_NO_IMPROVEMENT: int = 10  # maximum number of consecutive non-improving steps
    TARGET_THRESHOLD: float = 10000  # acceptable threshold for the objective value

            flight_time = flight_time_factor * d
            bs_time = bs_time_factor * (1 - rho) * d
    def __post_init__(self):
        """Initialize the derived parameters"""
        self.VALUES = [round(i*self.STEP, 2) for i in range(101)]
        self.ACTION_DELTA = [self.STEP, -self.STEP]
        # TODO: four variables need to be modified
        self.ACTIONS = [(i, delta) for i in range(4)
                        for delta in self.ACTION_DELTA]

            rectangles.append({
                'flight_time': flight_time,
                'bs_time': bs_time,
                'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * H,
                           (col_cuts[j] + col_cuts[j+1]) / 2.0 * W)
            })

    mortorcade_time_lt = []
    for idx in range(num_cars):
        car_path = car_paths[idx]

        flight_time = sum(rectangles[point]['flight_time']
                          for point in car_path)
        bs_time = sum(rectangles[point]['bs_time'] for point in car_path)

        car_time = 0
        for i in range(len(car_path) - 1):
            first_point = car_path[i]
            second_point = car_path[i + 1]
            car_time += math.dist(
                rectangles[first_point]['center'], rectangles[second_point]['center']) * car_time_factor
        car_time += math.dist(rectangles[car_path[0]]['center'],
                              [H / 2, W / 2]) * car_time_factor
        car_time += math.dist(rectangles[car_path[-1]]['center'],
                              [H / 2, W / 2]) * car_time_factor
        mortorcade_time_lt.append(max(car_time + flight_time, bs_time))

    return max(mortorcade_time_lt)

# environment class: defines the state transitions and rewards
class FunctionEnv:
    def __init__(self, initial_state):
        self.state = initial_state  # initial state (x1, x2, x3)
        self.best_value = float('inf')  # best value seen so far
        self.no_improvement_count = 0  # number of consecutive non-improving steps
        self.last_state = None  # previous state
        self.min_improvement = 0.001  # minimum improvement threshold
        self.max_no_improvement = 10  # maximum number of consecutive non-improving steps
        self.target_threshold = 10000  # acceptable threshold for the objective value

    def step(self, action):
        # action: (var_index, delta)
    """Environment class: defines the state transitions and rewards"""

    def __init__(self, params_file: str, initial_state: Tuple[float, float, float], cut_index: int, car_paths: List[List[int]], config: Config):
        self.params_file = params_file
        with open(self.params_file + '.yml', 'r', encoding='utf-8') as file:
            params = yaml.safe_load(file)

        self.H = params['H']
        self.W = params['W']
        self.num_cars = params['num_cars']

        self.flight_time_factor = params['flight_time_factor']
        self.comp_time_factor = params['comp_time_factor']
        self.trans_time_factor = params['trans_time_factor']
        self.car_time_factor = params['car_time_factor']
        self.bs_time_factor = params['bs_time_factor']

        self.flight_energy_factor = params['flight_energy_factor']
        self.comp_energy_factor = params['comp_energy_factor']
        self.trans_energy_factor = params['trans_energy_factor']
        self.battery_energy_capacity = params['battery_energy_capacity']

        self.state = initial_state
        self.cut_index = cut_index
        self.config = config
        self.car_paths = car_paths
        self.best_value = float('inf')
        self.no_improvement_count = 0
        self.last_state = None

    def step(self, action: Tuple[int, float]) -> Tuple[Tuple[float, float, float], float, bool]:
        """
        Execute one action

        Args:
            action: tuple of (variable index, delta)

        Returns:
            Tuple[Tuple[float, float, float], float, bool]: (next state, reward, done)
        """
        var_index, delta = action
        new_state = list(self.state)
        new_state[var_index] = round(new_state[var_index] + delta, 2)
        # keep the value within the range 0-1
        if new_state[var_index] < 0 or new_state[var_index] > 1:
            return self.state, -10000.0, True  # end of episode
        # check the constraint: x1 < x2 < x3
        if not (0 < new_state[0] < new_state[1] < new_state[2] < 1):
            row_cuts_state = new_state[:self.cut_index]
            col_cuts_state = new_state[self.cut_index:]

        # check the constraints
        if not self._is_valid_state(row_cuts_state) or not self._is_valid_state(col_cuts_state):
            return self.state, -10000.0, True

        current_value = self.calculate_objective(
            row_cuts_state, col_cuts_state)

        # check the termination conditions
        if self._should_terminate(new_state, current_value):
            return new_state, 12000 - current_value, True

        self._update_state(new_state, current_value)
        return new_state, 12000 - current_value, False

    def _is_valid_state(self, state: List[float]) -> bool:
        """
        Check whether the state satisfies the constraints:
        the elements of the list must be strictly increasing and lie in (0, 1)

        next_state = tuple(new_state)
        current_value = f(next_state)

        # check whether the target threshold has been reached
        if current_value < self.target_threshold:
            return next_state, 12000 - current_value, True
        Args:
            state: the state list to check

        # check whether the state change is very small
        Returns:
            bool: whether the constraints are satisfied
        """
        # check whether the list is empty
        if not state:
            return False

        # check that all elements lie in (0, 1)
        if not all(0 < x < 1 for x in state):
            return False

        # check that the elements are strictly increasing
        for i in range(len(state) - 1):
            if state[i] >= state[i + 1]:
                return False

        return True

    def _should_terminate(self, state: Tuple[float, float, float], value: float) -> bool:
        """Check whether the episode should terminate"""
        if value < self.config.TARGET_THRESHOLD:
            return True

        if self.last_state is not None:
            state_diff = sum(abs(a - b) for a, b in zip(next_state, self.last_state))
            if state_diff < self.min_improvement:
            state_diff = sum(abs(a - b)
                             for a, b in zip(state, self.last_state))
            if state_diff < self.config.MIN_IMPROVEMENT:
                self.no_improvement_count += 1
            else:
                self.no_improvement_count = 0

        # check whether there is an improvement
        if current_value < self.best_value:
            self.best_value = current_value

        if value < self.best_value:
            self.best_value = value
            self.no_improvement_count = 0
        else:
            self.no_improvement_count += 1

        # end the episode after too many consecutive steps without improvement
        if self.no_improvement_count >= self.max_no_improvement:
            return next_state, 12000 - current_value, True

        self.last_state = next_state
        self.state = next_state
        return next_state, 12000 - current_value, False

    def reset(self, state):

        return self.no_improvement_count >= self.config.MAX_NO_IMPROVEMENT

    def _update_state(self, state: Tuple[float, float, float], value: float) -> None:
        """Update the state"""
        self.last_state = self.state
        self.state = state

    def calculate_objective(self, row_cuts, col_cuts):
        """
        Compute the objective value T for the given cut ratios (placeholder function)
        :param row_cuts: row cut ratios
        :param col_cuts: column cut ratios
        :return: objective value T
        """
        row_cuts = [0] + row_cuts + [1]
        col_cuts = [0] + col_cuts + [1]
        rectangles = []
        for i in range(len(row_cuts) - 1):
            for j in range(len(col_cuts) - 1):
                d = (col_cuts[j+1] - col_cuts[j]) * self.W * \
                    (row_cuts[i+1] - row_cuts[i]) * self.H
                rho_time_limit = (self.flight_time_factor - self.trans_time_factor) / \
                    (self.comp_time_factor - self.trans_time_factor)
                rho_energy_limit = (self.battery_energy_capacity - self.flight_energy_factor * d - self.trans_energy_factor * d) / \
                    (self.comp_energy_factor * d -
                     self.trans_energy_factor * d)
                if rho_energy_limit < 0:
                    return 1000000
                rho = min(rho_time_limit, rho_energy_limit)

                flight_time = self.flight_time_factor * d
                bs_time = self.bs_time_factor * (1 - rho) * d

                rectangles.append({
                    'flight_time': flight_time,
                    'bs_time': bs_time,
                    'center': ((row_cuts[i] + row_cuts[i+1]) / 2.0 * self.H,
                               (col_cuts[j] + col_cuts[j+1]) / 2.0 * self.W)
                })

        mortorcade_time_lt = []
        for idx in range(self.num_cars):
            car_path = self.car_paths[idx]

            flight_time = sum(rectangles[point]['flight_time']
                              for point in car_path)
            bs_time = sum(rectangles[point]['bs_time'] for point in car_path)

            car_time = 0
            for i in range(len(car_path) - 1):
                first_point = car_path[i]
                second_point = car_path[i + 1]
                car_time += math.dist(
                    rectangles[first_point]['center'], rectangles[second_point]['center']) * self.car_time_factor
            car_time += math.dist(rectangles[car_path[0]]['center'],
                                  [self.H / 2, self.W / 2]) * self.car_time_factor
            car_time += math.dist(rectangles[car_path[-1]]['center'],
                                  [self.H / 2, self.W / 2]) * self.car_time_factor
            mortorcade_time_lt.append(max(car_time + flight_time, bs_time))

        return max(mortorcade_time_lt)

    def reset(self, state: Tuple[float, float, float]) -> Tuple[float, float, float]:
        """Reset the environment state"""
        self.state = state
        return self.state

# initialize the Q-table: a dict whose keys are state tuples and whose values are dicts mapping action -> Q value
Q_table = {}

def get_Q(state, action):
    if state not in Q_table:
        Q_table[state] = {a: 0.0 for a in ACTIONS}
    return Q_table[state][action]
class QLearning:
    """Q-learning algorithm implementation"""

def set_Q(state, action, value):
    if state not in Q_table:
        Q_table[state] = {a: 0.0 for a in ACTIONS}
    Q_table[state][action] = value
    def __init__(self, config: Config):
        self.config = config
        self.Q_table: Dict[Tuple[float, float, float],
                           Dict[Tuple[int, float], float]] = {}

def choose_action(state, epsilon):
    # ε-greedy policy
    if random.random() < epsilon:
        return random.choice(ACTIONS)
    else:
        if state not in Q_table:
            Q_table[state] = {a: 0.0 for a in ACTIONS}
        # return the action with the largest Q value
        return max(Q_table[state].items(), key=lambda x: x[1])[0]
    def get_Q(self, state: Tuple[float, float, float], action: Tuple[int, float]) -> float:
        """Get a Q value"""
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        return self.Q_table[state][action]

def load_initial_solution(file_path):
    def set_Q(self, state: Tuple[float, float, float], action: Tuple[int, float], value: float) -> None:
        """Set a Q value"""
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        self.Q_table[state][action] = value

    def choose_action(self, state: Tuple[float, float, float], epsilon: float) -> Tuple[int, float]:
        """Choose an action (ε-greedy policy)"""
        if random.random() < epsilon:
            return random.choice(self.config.ACTIONS)
        if state not in self.Q_table:
            self.Q_table[state] = {a: 0.0 for a in self.config.ACTIONS}
        return max(self.Q_table[state].items(), key=lambda x: x[1])[0]


def load_initial_solution(file_path: str) -> Tuple[List[float], List[float], List[List[int]]]:
    """
    Load the initial solution from a JSON file
    :param file_path: path of the JSON file
    :return: row cut ratios, column cut ratios
    Load the initial solution from a JSON file

    Args:
        file_path: path of the JSON file

    Returns:
        Tuple[List[float], List[float], List[List[int]]]: (row cut ratios, column cut ratios, car paths)
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    row_cuts = data['row_boundaries']
    col_cuts = data['col_boundaries']
    car_paths = data['car_paths']
    return row_cuts, col_cuts, car_paths
    return data['row_boundaries'], data['col_boundaries'], data['car_paths']

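For reference, a rough sketch of the JSON shape that load_initial_solution expects, based only on the keys read above; the values are made up and the real files live under solutions\ in this repository:

# hypothetical example of a solution file (illustrative values, not from the repo)
# {
#     "row_boundaries": [0, 0.5, 1],
#     "col_boundaries": [0, 0.25, 0.5, 0.75, 1],
#     "car_paths": [[0, 1, 2], [5, 4, 3], [6, 7]]
# }
# main() below drops the leading and trailing boundary (via [1:-1]) to build the initial Q-learning state.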
if __name__ == "__main__":

def main():
    """Main function"""
    # initialize the configuration
    config = Config()
    random.seed(42)

    # ---------------------------
    # hyperparameters that need to be modified
    # ---------------------------
    # load the initial solution
    solution_path = r"solutions\trav_ga_params2_parallel.json"
    params_file = r"params2"

    initial_row_cuts, initial_col_cuts, car_paths = load_initial_solution(
        solution_path)

    initial_state = (0.2, 0.4, 0.7)
    # Q-learning main loop
    env = FunctionEnv(initial_state)
    initial_state = initial_row_cuts[1:-1] + initial_col_cuts[1:-1]
    cut_index = len(initial_row_cuts) - 2

    for episode in range(NUM_EPISODES):
        print(f"Episode {episode + 1} of {NUM_EPISODES}")
    # initialize the environment and the Q-learning agent
    env = FunctionEnv(params_file, initial_state, cut_index, car_paths, config)
    q_learning = QLearning(config)

    # training loop
    for episode in range(config.NUM_EPISODES):
        print(f"Episode {episode + 1} of {config.NUM_EPISODES}")
        state = env.reset(initial_state)
        done = False

        while not done:
            # choose an action
            action = choose_action(state, EPSILON)
            # execute the action in the environment
            action = q_learning.choose_action(tuple(state), config.EPSILON)
            next_state, reward, done = env.step(action)
            # Q-learning update: Q(s,a) = Q(s,a) + α [r + γ * max_a' Q(s', a') - Q(s,a)]
            if next_state not in Q_table:
                Q_table[next_state] = {a: 0.0 for a in ACTIONS}
            max_next_Q = max(Q_table[next_state].values())
            current_Q = get_Q(state, action)
            new_Q = current_Q + ALPHA * (reward + GAMMA * max_next_Q - current_Q)
            set_Q(state, action, new_Q)
            next_state = tuple(next_state)

            # Q-learning update
            if next_state not in q_learning.Q_table:
                q_learning.Q_table[next_state] = {
                    a: 0.0 for a in config.ACTIONS}
            max_next_Q = max(q_learning.Q_table[next_state].values())
            current_Q = q_learning.get_Q(tuple(state), action)
            new_Q = current_Q + config.ALPHA * \
                (reward + config.GAMMA * max_next_Q - current_Q)
            q_learning.set_Q(tuple(state), action, new_Q)
            state = next_state

        # gradually decrease the exploration rate
        EPSILON = max(0.01, EPSILON * 0.999)
        # update the exploration rate
        config.EPSILON = max(0.01, config.EPSILON * 0.999)

    # print the best state and value found in the Q-table
    # print the best solution
    best_state = None
    best_value = float('inf')
    for state in Q_table:
        # evaluate the quality of a solution by its objective value
        state_value = f(state)
    for state in q_learning.Q_table:
        state = list(state)
        state_value = env.calculate_objective(
            state[:cut_index], state[cut_index:])
        if state_value < best_value:
            best_value = state_value
            best_state = state

    print("Best state found:", best_state, "objective value:", best_value)


if __name__ == "__main__":
    main()

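As a quick sanity check on the tabular update used in the training loop above, a standalone sketch with made-up numbers (ALPHA=0.1 and GAMMA=0.9 as in Config; the reward follows the 12000 - current_value convention from step):

# standalone numeric check of Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
ALPHA, GAMMA = 0.1, 0.9
current_Q = 0.0           # Q(s, a) before the update
reward = 12000 - 11500    # illustrative objective value of 11500
max_next_Q = 450.0        # illustrative best Q value in the next state
new_Q = current_Q + ALPHA * (reward + GAMMA * max_next_Q - current_Q)
print(new_Q)              # 0.1 * (500 + 405) = 90.5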