修改dqn奖励 (Modify DQN reward)
parent 58952f1fdb
commit db04a87ffd
@@ -142,7 +142,7 @@ def main():
             if total_steps % 1000 == 0:
                 agent.exp_noise *= opt.noise_decay
             if total_steps % opt.eval_interval == 0:
-                score = evaluate_policy(eval_env, agent, turns=3)
+                score = evaluate_policy(eval_env, agent, turns=1)
                 if opt.write:
                     writer.add_scalar(
                         'ep_r', score, global_step=total_steps)
@@ -1,23 +1,80 @@
+import json
+from datetime import datetime
+import copy
+
+
 def evaluate_policy(env, agent, turns=3):
+    """
+    Evaluate the policy.
+
+    Args:
+        env: environment object
+        agent: agent object
+        turns: number of evaluation rounds
+    Returns:
+        int: average score
+    """
     total_scores = 0
-    for j in range(turns):
+    # for j in range(turns):
     s = env.reset()
     done = False
-    action_series = []
+    eval_info = {'action_series': [],
+                 # 'state_series': [],
+                 'reward_series': []}
+    info_lt = []
+
     while not done:
-        # Take deterministic actions at test time
         a = agent.select_action(s, deterministic=True)
         s_next, r, dw, tr, info = env.step(a)
        done = (dw or tr)
-        action_series.append(a)
+        eval_info['action_series'].append(a)
+        eval_info['reward_series'].append(r)
+        info_lt.append(copy.deepcopy(info))
+
         total_scores += r
         s = s_next
-    print('action series: ', action_series)
-    print('state: ', s)
+    print(eval_info)
+    save_best_solution(info_lt)
+
     return int(total_scores/turns)
 
 
+def save_best_solution(info_lt):
+    # Find the best solution of this evaluation episode
+    best_info = min(info_lt, key=lambda x: x['best_time'])
+
+    # Load the previously saved best solution
+    try:
+        with open('solutions/dqn_params_50_50_3.json', 'r') as f:
+            saved_solution = json.load(f)
+            saved_time = saved_solution['best_time']
+    except FileNotFoundError:
+        saved_time = float('inf')
+
+    # If the new solution is better, update the json file
+    if best_info['best_time'] < saved_time:
+        best_solution = {
+            'best_time': best_info['best_time'],
+            'row_cuts': best_info['row_cuts'],
+            'col_cuts': best_info['col_cuts'],
+            'best_path': best_info['best_path'],
+            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        }
+
+        with open('solutions/dqn_params_50_50_3.json', 'w') as f:
+            json.dump(best_solution, f, indent=4)
+
+        print(f"Found a new best solution! Time: {best_info['best_time']}")
+
+
+def compare_lists(list1, list2):
+    return len(list1) == len(list2) and all(a == b for a, b in zip(list1, list2))
+
+
 # You can just ignore this function. It is not related to the RL.
 def str2bool(v):
     '''transfer str to bool for argparse'''
     if isinstance(v, bool):
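Note (illustration, not part of the commit): the info dicts that the new evaluate_policy gathers into info_lt come from PartitionEnv.step() (see the env diff further down) and carry the keys row_cuts, col_cuts, best_path and best_time. The sketch below shows the selection rule save_best_solution() applies; the first entry's best_time is a made-up value, the second mirrors the JSON file added by this commit.

import json

# toy list standing in for the per-step info dicts collected during one evaluation episode
info_lt = [
    {'row_cuts': [0, 0.2, 0.4, 0.7, 1], 'col_cuts': [0, 0.5, 1],
     'best_path': [7, 8, 0, 6, 2, 4, 10, 9, 5, 3, 1], 'best_time': 9120.4},   # hypothetical value
    {'row_cuts': [0, 0.21, 0.4, 0.7, 1], 'col_cuts': [0, 0.5, 1],
     'best_path': [7, 8, 0, 6, 2, 4, 10, 9, 5, 3, 1], 'best_time': 9051.16},  # matches the committed JSON
]

# same rule as save_best_solution(): keep the step with the smallest best_time
best_info = min(info_lt, key=lambda x: x['best_time'])
# roughly the payload written to solutions/dqn_params_50_50_3.json (a timestamp is added before saving)
print(json.dumps(best_info, indent=4))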
GA/ga.py (2 changed lines)
@@ -16,7 +16,7 @@ class GA(object):
         self.location = data
         self.to_process_idx = to_process_idx
         self.rectangles = rectangles
-        self.epochs = 1000
+        self.epochs = 1500
         self.ga_choose_ratio = 0.2
         self.mutate_ratio = 0.05
         # fruits stores each individual as a list of indices
@@ -18,11 +18,11 @@ class PartitionEnv(gym.Env):
         ##############################
         # Hyperparameters that may need manual tuning
         ##############################
-        self.params = 'params2'
+        self.params = 'params_50_50_3'
         self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1]
         self.ORI_COL_CUTS = [0, 0.5, 1]
         self.CUT_NUM = 4
-        self.BASE_LINE = 9100
+        self.BASE_LINE = 9051.16
         self.MAX_ADJUST_STEP = 50
         self.ADJUST_THRESHOLD = 0.1
         # self.mTSP_STEPS = 10000
@@ -115,11 +115,13 @@ class PartitionEnv(gym.Env):
             reward = self.calc_reward(best_time)
             self.adjust_step += 1
             state = np.array(self.row_cuts + self.col_cuts)
+            info = {'row_cuts': self.row_cuts, 'col_cuts': self.col_cuts,
+                    'best_path': self.best_path, 'best_time': best_time}
 
             if self.adjust_step < self.MAX_ADJUST_STEP:
-                return state, reward, False, False, {}
+                return state, reward, False, False, info
             else:
-                return state, reward, True, False, {}
+                return state, reward, True, False, info
 
     def if_valid_partition(self):
         rectangles = []
@@ -221,7 +223,11 @@ class PartitionEnv(gym.Env):
 
     def calc_reward(self, best_time):
         """
-        Compute the reward
+        Compute the reward:
+        1. give a positive reward if the time is below the baseline
+        2. give a negative reward if the time is above the baseline
+        3. keep the normalization and the discount factor
+
         Args:
             best_time (float): time of the current path
         Returns:
@@ -229,14 +235,15 @@ class PartitionEnv(gym.Env):
         """
         time_diff = self.BASE_LINE - best_time
 
-        # Normalize the time difference
-        normalized_diff = 1 / (1 + np.exp(-time_diff/20))
+        # Normalize with tanh so that normalized_diff = 0 when time_diff = 0.
+        # tanh is already very close to 1 when its argument is around 2; the largest time_diff is about 400.
+        normalized_diff = np.tanh(time_diff / 200)  # 200 is the scaling factor and can be adjusted
 
-        # Compute the step weight
+        # Compute the step weight (discount factor)
         step_weight = 1 / (1 + np.exp(-self.adjust_step/10))
 
-        # Compute the final reward (with a scaling factor)
-        reward = normalized_diff * step_weight * 10  # 10 is the scaling factor
+        # Compute the final reward
+        reward = normalized_diff * step_weight
 
         return reward
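Note (illustration, not part of the commit): the old sigmoid shaping 1/(1+exp(-time_diff/20)) is 0.5 even when best_time equals the baseline and never becomes negative, while the new tanh(time_diff/200) is 0 at the baseline and negative when best_time is worse than it. A minimal standalone sketch, assuming BASE_LINE = 9051.16 from the env diff above and an arbitrary adjust_step of 25:

import numpy as np

BASE_LINE = 9051.16  # baseline set in PartitionEnv above

def old_reward(best_time, adjust_step):
    # previous shaping: sigmoid of the time difference, scaled by 10
    time_diff = BASE_LINE - best_time
    return 1 / (1 + np.exp(-time_diff / 20)) * (1 / (1 + np.exp(-adjust_step / 10))) * 10

def new_reward(best_time, adjust_step):
    # new shaping: tanh keeps the sign of the improvement and is 0 at the baseline
    time_diff = BASE_LINE - best_time
    return np.tanh(time_diff / 200) * (1 / (1 + np.exp(-adjust_step / 10)))

for best_time in (8900.0, 9051.16, 9200.0):  # better than, equal to, worse than the baseline
    print(best_time, round(old_reward(best_time, 25), 3), round(new_reward(best_time, 25), 3))

At the baseline the old form still pays out about half of its maximum reward, whereas the new form returns 0 and turns negative once best_time exceeds BASE_LINE, matching items 1-2 of the updated docstring.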
@@ -12,7 +12,7 @@ random.seed(42)
 # ---------------------------
 # Hyperparameters to modify
 # ---------------------------
-num_iterations = 100000000
+num_iterations = 3000000000
 # Randomly generate the numbers of row and column segments for the partition
 R = random.randint(0, 3)  # number of row segments
 C = random.randint(0, 3)  # number of column segments
@@ -47,13 +47,15 @@ best_solution = None
 
 for iteration in tqdm(range(num_iterations), desc="Monte Carlo simulation progress"):
     # Sample the cut positions directly
-    horiz = [random.random() for _ in range(R)]
+    # horiz = [random.random() for _ in range(R)]
+    horiz = [random.randint(1, 999)/1000 for _ in range(R)]
     horiz = sorted(set(horiz))
     horiz = horiz if horiz else []
     row_boundaries = [0] + horiz + [1]
     row_boundaries = [boundary * H for boundary in row_boundaries]
 
-    vert = [random.random() for _ in range(C)]
+    # vert = [random.random() for _ in range(C)]
+    vert = [random.randint(1, 999)/1000 for _ in range(C)]
     vert = sorted(set(vert))
     vert = vert if vert else []
     col_boundaries = [0] + vert + [1]
@@ -151,8 +153,6 @@ for iteration in tqdm(range(num_iterations), desc="Monte Carlo simulation progress"):
 
     T_max = max(T_k_list)  # the overall objective T is the largest T_k among the sub-systems
 
-    # TODO the total energy consumption of the system is not constrained
-
     if T_max < best_T:
         best_T = T_max
         best_solution = {
@@ -168,6 +168,7 @@ for iteration in tqdm(range(num_iterations), desc="Monte Carlo simulation progress"):
             'flight_time': total_flight_time,
             'bs_time': total_bs_time
         }
+        print(iteration)
 
 # ---------------------------
 # Output the best solution
solutions/dqn_params_50_50_3.json (new file, 29 lines)
@@ -0,0 +1,29 @@
+{
+    "best_time": 9051.162633521315,
+    "row_cuts": [
+        0,
+        0.21000000000000002,
+        0.4,
+        0.7,
+        1
+    ],
+    "col_cuts": [
+        0,
+        0.5,
+        1
+    ],
+    "best_path": [
+        7,
+        8,
+        0,
+        6,
+        2,
+        4,
+        10,
+        9,
+        5,
+        3,
+        1
+    ],
+    "timestamp": "2025-04-01 17:43:22"
+}
@@ -14,8 +14,8 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r
 
     # Set the English title and labels
     # ax.set_title("Monte Carlo", fontsize=12)
-    ax.set_title("Greedy", fontsize=12)
-    # ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
+    # ax.set_title("Greedy", fontsize=12)
+    ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
     # ax.set_title("DQN fine-tuning", fontsize=12)
 
     ax.set_xlabel("Region Width", fontsize=10)
@@ -200,7 +200,7 @@ if __name__ == "__main__":
     # Hyperparameters to modify
     # ---------------------------
     params_file = 'params_50_50_3'
-    solution_file = r'solutions\greedy_params_50_50_3.json'
+    solution_file = r'solutions\trav_ga_params_50_50_3_parallel.json'
 
     with open(params_file + '.yml', 'r', encoding='utf-8') as file:
         params = yaml.safe_load(file)