修改dqn bug

This commit is contained in:
weixin_46229132 2025-04-01 20:45:13 +08:00
parent db04a87ffd
commit 981681c1bd
6 changed files with 32 additions and 27 deletions

View File

@ -16,7 +16,7 @@ class GA(object):
self.location = data self.location = data
self.to_process_idx = to_process_idx self.to_process_idx = to_process_idx
self.rectangles = rectangles self.rectangles = rectangles
self.epochs = 1500 self.epochs = 300
self.ga_choose_ratio = 0.2 self.ga_choose_ratio = 0.2
self.mutate_ratio = 0.05 self.mutate_ratio = 0.05
# fruits中存每一个个体是下标的list # fruits中存每一个个体是下标的list
@ -320,6 +320,7 @@ class GA(object):
best_length = 1.0 / best_score best_length = 1.0 / best_score
# print(f"Epoch {i:3}: {best_length:.3f}") # print(f"Epoch {i:3}: {best_length:.3f}")
# print(1.0 / best_score) # print(1.0 / best_score)
# print(i)
return tmp_best_one, 1.0 / best_score return tmp_best_one, 1.0 / best_score
# if __name__ == '__main__': # if __name__ == '__main__':

View File

@ -20,7 +20,7 @@ best_col_boundaries = None
# --------------------------- # ---------------------------
R = 3 R = 3
C = 3 C = 3
params_file = 'params2' params_file = 'params_50_50_3'
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:

View File

@ -89,7 +89,7 @@ class PartitionEnv(gym.Env):
self.col_cuts[1] += 0.01 self.col_cuts[1] += 0.01
elif action == 8: elif action == 8:
self.col_cuts[1] -= 0.01 self.col_cuts[1] -= 0.01
elif action == 9: elif action == 0:
pass pass
if self.row_cuts[0] < self.row_cuts[1] < self.row_cuts[2] < self.row_cuts[3] < self.row_cuts[4] and self.col_cuts[0] < self.col_cuts[1] < self.col_cuts[2]: if self.row_cuts[0] < self.row_cuts[1] < self.row_cuts[2] < self.row_cuts[3] < self.row_cuts[4] and self.col_cuts[0] < self.col_cuts[1] < self.col_cuts[2]:
@ -100,12 +100,14 @@ class PartitionEnv(gym.Env):
# 不满足条件,时间给一个很大的值 # 不满足条件,时间给一个很大的值
best_time = self.BASE_LINE * 2 best_time = self.BASE_LINE * 2
else: else:
# 满足条件,继续进行路径规划 # # 满足条件,继续进行路径规划
# 每隔10步计算一次路径第一次也需要计算路径记录最佳路径 # # 每隔10步计算一次路径第一次也需要计算路径记录最佳路径
if self.adjust_step % 10 == 0 or self.adjust_step == 1 or self.best_path is None: # if self.adjust_step % 10 == 0 or self.adjust_step == 1 or self.best_path is None:
best_time, self.best_path = self.ga_solver(rectangles) # best_time, self.best_path = self.ga_solver(rectangles)
else: # else:
# 根据最佳路径计算当前时间 # # 根据最佳路径计算当前时间
# best_time = self.get_best_time(self.best_path, rectangles)
self.best_path = [0, 1, 3, 5, 9, 7, 8, 10, 2, 4, 6, 0]
best_time = self.get_best_time(self.best_path, rectangles) best_time = self.get_best_time(self.best_path, rectangles)
else: else:
@ -240,10 +242,11 @@ class PartitionEnv(gym.Env):
normalized_diff = np.tanh(time_diff / 200) # 20是缩放因子可调整 normalized_diff = np.tanh(time_diff / 200) # 20是缩放因子可调整
# 计算轮次权重(折扣因子) # 计算轮次权重(折扣因子)
step_weight = 1 / (1 + np.exp(-self.adjust_step/10)) # step_weight = 1 / (1 + np.exp(-self.adjust_step/10))
# 计算最终奖励 # 计算最终奖励
reward = normalized_diff * step_weight # 10是缩放因子 reward = normalized_diff
# * step_weight # 10是缩放因子
return reward return reward

View File

@ -9,8 +9,8 @@ state = env.reset()
print('state:', state) print('state:', state)
# action_series = [[0.67], [0], [0], [0], [0.7]] # action_series = [[0.67], [0], [0], [0], [0.7]]
action_series = [1, 1, 1, 1, 1, 1] action_series = [3, 3, 3, 5, 5, 1, 1, 1, 0, 0, 0]
action_series = [1] * 30 # action_series = [1] * 30
# action_series = [[0.2], [0.4], [0.7], [0.5]] # action_series = [[0.2], [0.4], [0.7], [0.5]]
# action_series = [[-0.08], [-0.08], [0], [0]] # action_series = [[-0.08], [-0.08], [0], [0]]

View File

@ -1,10 +1,10 @@
{ {
"best_time": 9051.162633521315, "best_time": 8848.626166217664,
"row_cuts": [ "row_cuts": [
0, 0,
0.21000000000000002, 0.2700000000000001,
0.4, 0.4700000000000001,
0.7, 0.76,
1 1
], ],
"col_cuts": [ "col_cuts": [
@ -13,17 +13,18 @@
1 1
], ],
"best_path": [ "best_path": [
0,
1,
3,
5,
9,
7, 7,
8, 8,
0, 10,
6,
2, 2,
4, 4,
10, 6,
9, 0
5,
3,
1
], ],
"timestamp": "2025-04-01 17:43:22" "timestamp": "2025-04-01 20:05:51"
} }

View File

@ -200,7 +200,7 @@ if __name__ == "__main__":
# 需要修改的超参数 # 需要修改的超参数
# --------------------------- # ---------------------------
params_file = 'params_50_50_3' params_file = 'params_50_50_3'
solution_file = r'solutions\trav_ga_params_50_50_3_parallel.json' solution_file = r'solutions\finetune_params_50_50_3.json'
with open(params_file + '.yml', 'r', encoding='utf-8') as file: with open(params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file) params = yaml.safe_load(file)