dqn 100_100_6

This commit is contained in:
weixin_46229132 2025-04-05 11:06:08 +08:00
parent c6c7cb47f1
commit 90ad3e829d
3 changed files with 70 additions and 13 deletions

View File

@ -19,10 +19,10 @@ class PartitionEnv(gym.Env):
# 可能需要手动修改的超参数 # 可能需要手动修改的超参数
############################## ##############################
self.params = 'params_100_100_6' self.params = 'params_100_100_6'
self.ORI_ROW_CUTS = [0, 0.2, 0.5, 0.7, 1] self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.6, 0.8, 1]
self.ORI_COL_CUTS = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1] self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1]
self.CUT_NUM = 12 self.CUT_NUM = 9
self.BASE_LINE = 19616.68 self.BASE_LINE = 19757.42
self.MAX_ADJUST_STEP = 80 self.MAX_ADJUST_STEP = 80
# self.ADJUST_THRESHOLD = 0.1 # self.ADJUST_THRESHOLD = 0.1
# self.mTSP_STEPS = 10000 # self.mTSP_STEPS = 10000
@ -94,16 +94,16 @@ class PartitionEnv(gym.Env):
cut_index, signal = (action + 1) // 2, (action + 1) % 2 cut_index, signal = (action + 1) // 2, (action + 1) % 2
if action == 0: if action == 0:
pass pass
elif cut_index <= 3: elif cut_index <= 4:
if signal == 0: if signal == 0:
self.row_cuts[cut_index] += 0.01 self.row_cuts[cut_index] += 0.01
else: else:
self.row_cuts[cut_index] -= 0.01 self.row_cuts[cut_index] -= 0.01
else: else:
if signal == 0: if signal == 0:
self.col_cuts[cut_index-3] += 0.01 self.col_cuts[cut_index-4] += 0.01
else: else:
self.col_cuts[cut_index-3] -= 0.01 self.col_cuts[cut_index-4] -= 0.01
# 检查row_cuts和col_cuts是否按升序排列 # 检查row_cuts和col_cuts是否按升序排列
if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and
@ -122,8 +122,8 @@ class PartitionEnv(gym.Env):
# else: # else:
# # 根据最佳路径计算当前时间 # # 根据最佳路径计算当前时间
# best_time = self.get_best_time(self.best_path, rectangles) # best_time = self.get_best_time(self.best_path, rectangles)
self.best_path = [0, 17, 10, 9, 8, 7, 6, 5, 0, 28, 29, 30, 19, 20, 18, 16, 43, 27, 40, 39, 38, 37, self.best_path = [33, 30, 29, 28, 27, 21, 15, 0, 13, 7, 1, 2, 31, 14, 8, 3, 4,
36, 26, 45, 14, 13, 12, 11, 22, 21, 23, 24, 41, 44, 25, 34, 35, 33, 32, 31, 42, 15, 4, 3, 2, 1, 0] 10, 32, 23, 22, 24, 18, 17, 16, 35, 9, 12, 6, 5, 11, 34, 20, 25, 26, 19, 0]
best_time = self.get_best_time(self.best_path, rectangles) best_time = self.get_best_time(self.best_path, rectangles)
else: else:
@ -257,9 +257,6 @@ class PartitionEnv(gym.Env):
# tanh在变量值为2时就非常接近1了。最大的time_diff为400 # tanh在变量值为2时就非常接近1了。最大的time_diff为400
normalized_diff = np.tanh(time_diff / 5000) # 5000是缩放因子，可调整 normalized_diff = np.tanh(time_diff / 5000) # 5000是缩放因子，可调整
# 计算轮次权重(折扣因子)
# step_weight = 1 / (1 + np.exp(-self.adjust_step/10))
# 计算最终奖励 # 计算最终奖励
reward = normalized_diff reward = normalized_diff
# * step_weight # 10是缩放因子 # * step_weight # 10是缩放因子

View File

@ -13,7 +13,7 @@ print('state:', state)
# action_series = [1] * 30 # action_series = [1] * 30
# action_series = [[0.2], [0.4], [0.7], [0.5]] # action_series = [[0.2], [0.4], [0.7], [0.5]]
# action_series = [[-0.08], [-0.08], [0], [0]] # action_series = [[-0.08], [-0.08], [0], [0]]
action_series = [0, 0, 3, 4, 24, 20] action_series = list(range(19))
for i in range(100): for i in range(100):
action = action_series[i] action = action_series[i]

View File

@ -0,0 +1,60 @@
{
"best_time": 19643.795059416032,
"row_cuts": [
0,
0.2,
0.4,
0.6,
0.78,
1
],
"col_cuts": [
0,
0.2,
0.4,
0.5,
0.7,
0.8,
1
],
"best_path": [
33,
30,
29,
28,
27,
21,
15,
0,
13,
7,
1,
2,
31,
14,
8,
3,
4,
10,
32,
23,
22,
24,
18,
17,
16,
35,
9,
12,
6,
5,
11,
34,
20,
25,
26,
19,
0
],
"timestamp": "2025-04-05 11:03:20"
}