修改100_100_6的dqn场景

This commit is contained in:
weixin_46229132 2025-04-04 10:59:31 +08:00
parent 23aafc2998
commit 87ee65087f
6 changed files with 175 additions and 84 deletions

View File

@ -47,7 +47,7 @@ def save_best_solution(info_lt):
# 读取已有的最优解 # 读取已有的最优解
try: try:
with open('solutions/dqn_params_50_50_3.json', 'r') as f: with open('solutions/dqn_params_100_100_6.json', 'r') as f:
saved_solution = json.load(f) saved_solution = json.load(f)
saved_time = saved_solution['best_time'] saved_time = saved_solution['best_time']
except FileNotFoundError: except FileNotFoundError:
@ -63,7 +63,7 @@ def save_best_solution(info_lt):
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
} }
with open('solutions/dqn_params_50_50_3.json', 'w') as f: with open('solutions/dqn_params_100_100_6.json', 'w') as f:
json.dump(best_solution, f, indent=4) json.dump(best_solution, f, indent=4)
print(f"发现新的最优解!时间: {best_info['best_time']}") print(f"发现新的最优解!时间: {best_info['best_time']}")

View File

@ -18,13 +18,13 @@ class PartitionEnv(gym.Env):
############################## ##############################
# 可能需要手动修改的超参数 # 可能需要手动修改的超参数
############################## ##############################
self.params = 'params_50_50_3' self.params = 'params_100_100_6'
self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.7, 1] self.ORI_ROW_CUTS = [0, 0.2, 0.5, 0.7, 1]
self.ORI_COL_CUTS = [0, 0.5, 1] self.ORI_COL_CUTS = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
self.CUT_NUM = 4 self.CUT_NUM = 12
self.BASE_LINE = 9051.16 self.BASE_LINE = 19616.68
self.MAX_ADJUST_STEP = 50 self.MAX_ADJUST_STEP = 80
self.ADJUST_THRESHOLD = 0.1 # self.ADJUST_THRESHOLD = 0.1
# self.mTSP_STEPS = 10000 # self.mTSP_STEPS = 10000
# 切分位置+/-0.01 # 切分位置+/-0.01
@ -73,26 +73,41 @@ class PartitionEnv(gym.Env):
return state return state
def step(self, action): def step(self, action):
if action == 1: # if action == 1:
self.row_cuts[1] += 0.01 # self.row_cuts[1] += 0.01
elif action == 2: # elif action == 2:
self.row_cuts[1] -= 0.01 # self.row_cuts[1] -= 0.01
elif action == 3: # elif action == 3:
self.row_cuts[2] += 0.01 # self.row_cuts[2] += 0.01
elif action == 4: # elif action == 4:
self.row_cuts[2] -= 0.01 # self.row_cuts[2] -= 0.01
elif action == 5: # elif action == 5:
self.row_cuts[3] += 0.01 # self.row_cuts[3] += 0.01
elif action == 6: # elif action == 6:
self.row_cuts[3] -= 0.01 # self.row_cuts[3] -= 0.01
elif action == 7: # elif action == 7:
self.col_cuts[1] += 0.01 # self.col_cuts[1] += 0.01
elif action == 8: # elif action == 8:
self.col_cuts[1] -= 0.01 # self.col_cuts[1] -= 0.01
elif action == 0: # elif action == 0:
# pass
cut_index, signal = (action + 1) // 2, (action + 1) % 2
if action == 0:
pass pass
elif cut_index <= 3:
if signal == 0:
self.row_cuts[cut_index] += 0.01
else:
self.row_cuts[cut_index] -= 0.01
else:
if signal == 0:
self.col_cuts[cut_index-3] += 0.01
else:
self.col_cuts[cut_index-3] -= 0.01
if self.row_cuts[0] < self.row_cuts[1] < self.row_cuts[2] < self.row_cuts[3] < self.row_cuts[4] and self.col_cuts[0] < self.col_cuts[1] < self.col_cuts[2]: # 检查row_cuts和col_cuts是否按升序排列
if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and
all(self.col_cuts[i] < self.col_cuts[i+1] for i in range(len(self.col_cuts)-1))):
# 调整是合法的,验证分区情况是否满足条件 # 调整是合法的,验证分区情况是否满足条件
rectangles = self.if_valid_partition() rectangles = self.if_valid_partition()
@ -107,7 +122,8 @@ class PartitionEnv(gym.Env):
# else: # else:
# # 根据最佳路径计算当前时间 # # 根据最佳路径计算当前时间
# best_time = self.get_best_time(self.best_path, rectangles) # best_time = self.get_best_time(self.best_path, rectangles)
self.best_path = [0, 1, 3, 5, 9, 7, 8, 10, 2, 4, 6, 0] self.best_path = [0, 17, 10, 9, 8, 7, 6, 5, 0, 28, 29, 30, 19, 20, 18, 16, 43, 27, 40, 39, 38, 37,
36, 26, 45, 14, 13, 12, 11, 22, 21, 23, 24, 41, 44, 25, 34, 35, 33, 32, 31, 42, 15, 4, 3, 2, 1, 0]
best_time = self.get_best_time(self.best_path, rectangles) best_time = self.get_best_time(self.best_path, rectangles)
else: else:
@ -239,7 +255,7 @@ class PartitionEnv(gym.Env):
# 使用tanh归一化确保time_diff=0时normalized_diff=0 # 使用tanh归一化确保time_diff=0时normalized_diff=0
# tanh在变量值为2时就非常接近1了。最大的time_diff为400 # tanh在变量值为2时就非常接近1了。最大的time_diff为400
normalized_diff = np.tanh(time_diff / 200) # 20是缩放因子可调整 normalized_diff = np.tanh(time_diff / 5000) # 20是缩放因子可调整
# 计算轮次权重(折扣因子) # 计算轮次权重(折扣因子)
# step_weight = 1 / (1 + np.exp(-self.adjust_step/10)) # step_weight = 1 / (1 + np.exp(-self.adjust_step/10))

View File

@ -13,8 +13,7 @@ print('state:', state)
# action_series = [1] * 30 # action_series = [1] * 30
# action_series = [[0.2], [0.4], [0.7], [0.5]] # action_series = [[0.2], [0.4], [0.7], [0.5]]
# action_series = [[-0.08], [-0.08], [0], [0]] # action_series = [[-0.08], [-0.08], [0], [0]]
action_series = [3, 5, 3, 5, 1, 1, 3, 5, 1, 5, 3, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, action_series = [0, 0, 3, 4, 24, 20]
2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1]
for i in range(100): for i in range(100):
action = action_series[i] action = action_series[i]

View File

@ -0,0 +1,74 @@
{
"best_time": 19557.574055662244,
"row_cuts": [
0,
0.2,
0.5,
0.7,
1
],
"col_cuts": [
0,
0.1,
0.19,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1
],
"best_path": [
0,
17,
10,
9,
8,
7,
6,
5,
0,
28,
29,
30,
19,
20,
18,
16,
43,
27,
40,
39,
38,
37,
36,
26,
45,
14,
13,
12,
11,
22,
21,
23,
24,
41,
25,
44,
34,
35,
33,
32,
31,
42,
15,
4,
3,
2,
1,
0
],
"timestamp": "2025-04-04 10:47:47"
}

View File

@ -1,5 +1,12 @@
{ {
"row_boundaries": [ "row_boundaries": [
0.0,
0.2,
0.5,
0.7,
1.0
],
"col_boundaries": [
0.0, 0.0,
0.1, 0.1,
0.2, 0.2,
@ -12,63 +19,58 @@
0.9, 0.9,
1.0 1.0
], ],
"col_boundaries": [
0.0,
0.2,
0.4,
0.7,
1.0
],
"car_paths": [ "car_paths": [
[ [
17, 16,
5,
4,
0,
1,
2,
6,
10,
14
],
[
18,
13,
9, 9,
8, 8,
12,
16,
20,
32,
21
],
[
22,
26,
30,
34,
39,
35,
31,
27,
19
],
[
25,
24,
28,
36,
33,
37,
38,
29
],
[
15,
23,
7, 7,
6,
5,
4
],
[
27,
28,
29,
18,
19,
17,
15
],
[
26,
39,
38,
37,
36,
35,
25
],
[
13,
12,
11, 11,
3 10,
21,
20,
22,
23
],
[
24,
33,
34,
32,
31,
30
],
[
14,
3,
2,
1,
0
] ]
] ]
} }

View File

@ -200,7 +200,7 @@ if __name__ == "__main__":
# 需要修改的超参数 # 需要修改的超参数
# --------------------------- # ---------------------------
params_file = 'params_100_100_6' params_file = 'params_100_100_6'
solution_file = r'solutions\greedy_params_100_100_6.json' solution_file = r'solutions\trav_ga_params_100_100_6_parallel.json'
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file) params = yaml.safe_load(file)