修改dqn bug
This commit is contained in:
parent
db04a87ffd
commit
981681c1bd
3
GA/ga.py
3
GA/ga.py
@ -16,7 +16,7 @@ class GA(object):
|
|||||||
self.location = data
|
self.location = data
|
||||||
self.to_process_idx = to_process_idx
|
self.to_process_idx = to_process_idx
|
||||||
self.rectangles = rectangles
|
self.rectangles = rectangles
|
||||||
self.epochs = 1500
|
self.epochs = 300
|
||||||
self.ga_choose_ratio = 0.2
|
self.ga_choose_ratio = 0.2
|
||||||
self.mutate_ratio = 0.05
|
self.mutate_ratio = 0.05
|
||||||
# fruits中存每一个个体是下标的list
|
# fruits中存每一个个体是下标的list
|
||||||
@ -320,6 +320,7 @@ class GA(object):
|
|||||||
best_length = 1.0 / best_score
|
best_length = 1.0 / best_score
|
||||||
# print(f"Epoch {i:3}: {best_length:.3f}")
|
# print(f"Epoch {i:3}: {best_length:.3f}")
|
||||||
# print(1.0 / best_score)
|
# print(1.0 / best_score)
|
||||||
|
# print(i)
|
||||||
return tmp_best_one, 1.0 / best_score
|
return tmp_best_one, 1.0 / best_score
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
# if __name__ == '__main__':
|
||||||
|
@ -20,7 +20,7 @@ best_col_boundaries = None
|
|||||||
# ---------------------------
|
# ---------------------------
|
||||||
R = 3
|
R = 3
|
||||||
C = 3
|
C = 3
|
||||||
params_file = 'params2'
|
params_file = 'params_50_50_3'
|
||||||
|
|
||||||
|
|
||||||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||||||
|
@ -89,7 +89,7 @@ class PartitionEnv(gym.Env):
|
|||||||
self.col_cuts[1] += 0.01
|
self.col_cuts[1] += 0.01
|
||||||
elif action == 8:
|
elif action == 8:
|
||||||
self.col_cuts[1] -= 0.01
|
self.col_cuts[1] -= 0.01
|
||||||
elif action == 9:
|
elif action == 0:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.row_cuts[0] < self.row_cuts[1] < self.row_cuts[2] < self.row_cuts[3] < self.row_cuts[4] and self.col_cuts[0] < self.col_cuts[1] < self.col_cuts[2]:
|
if self.row_cuts[0] < self.row_cuts[1] < self.row_cuts[2] < self.row_cuts[3] < self.row_cuts[4] and self.col_cuts[0] < self.col_cuts[1] < self.col_cuts[2]:
|
||||||
@ -100,13 +100,15 @@ class PartitionEnv(gym.Env):
|
|||||||
# 不满足条件,时间给一个很大的值
|
# 不满足条件,时间给一个很大的值
|
||||||
best_time = self.BASE_LINE * 2
|
best_time = self.BASE_LINE * 2
|
||||||
else:
|
else:
|
||||||
# 满足条件,继续进行路径规划
|
# # 满足条件,继续进行路径规划
|
||||||
# 每隔10步计算一次路径,第一次也需要计算路径,记录最佳路径
|
# # 每隔10步计算一次路径,第一次也需要计算路径,记录最佳路径
|
||||||
if self.adjust_step % 10 == 0 or self.adjust_step == 1 or self.best_path is None:
|
# if self.adjust_step % 10 == 0 or self.adjust_step == 1 or self.best_path is None:
|
||||||
best_time, self.best_path = self.ga_solver(rectangles)
|
# best_time, self.best_path = self.ga_solver(rectangles)
|
||||||
else:
|
# else:
|
||||||
# 根据最佳路径计算当前时间
|
# # 根据最佳路径计算当前时间
|
||||||
best_time = self.get_best_time(self.best_path, rectangles)
|
# best_time = self.get_best_time(self.best_path, rectangles)
|
||||||
|
self.best_path = [0, 1, 3, 5, 9, 7, 8, 10, 2, 4, 6, 0]
|
||||||
|
best_time = self.get_best_time(self.best_path, rectangles)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 调整不合法,时间给一个很大的值
|
# 调整不合法,时间给一个很大的值
|
||||||
@ -240,10 +242,11 @@ class PartitionEnv(gym.Env):
|
|||||||
normalized_diff = np.tanh(time_diff / 200) # 20是缩放因子,可调整
|
normalized_diff = np.tanh(time_diff / 200) # 20是缩放因子,可调整
|
||||||
|
|
||||||
# 计算轮次权重(折扣因子)
|
# 计算轮次权重(折扣因子)
|
||||||
step_weight = 1 / (1 + np.exp(-self.adjust_step/10))
|
# step_weight = 1 / (1 + np.exp(-self.adjust_step/10))
|
||||||
|
|
||||||
# 计算最终奖励
|
# 计算最终奖励
|
||||||
reward = normalized_diff * step_weight # 10是缩放因子
|
reward = normalized_diff
|
||||||
|
# * step_weight # 10是缩放因子
|
||||||
|
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
@ -9,8 +9,8 @@ state = env.reset()
|
|||||||
print('state:', state)
|
print('state:', state)
|
||||||
|
|
||||||
# action_series = [[0.67], [0], [0], [0], [0.7]]
|
# action_series = [[0.67], [0], [0], [0], [0.7]]
|
||||||
action_series = [1, 1, 1, 1, 1, 1]
|
action_series = [3, 3, 3, 5, 5, 1, 1, 1, 0, 0, 0]
|
||||||
action_series = [1] * 30
|
# action_series = [1] * 30
|
||||||
# action_series = [[0.2], [0.4], [0.7], [0.5]]
|
# action_series = [[0.2], [0.4], [0.7], [0.5]]
|
||||||
# action_series = [[-0.08], [-0.08], [0], [0]]
|
# action_series = [[-0.08], [-0.08], [0], [0]]
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
{
|
{
|
||||||
"best_time": 9051.162633521315,
|
"best_time": 8848.626166217664,
|
||||||
"row_cuts": [
|
"row_cuts": [
|
||||||
0,
|
0,
|
||||||
0.21000000000000002,
|
0.2700000000000001,
|
||||||
0.4,
|
0.4700000000000001,
|
||||||
0.7,
|
0.76,
|
||||||
1
|
1
|
||||||
],
|
],
|
||||||
"col_cuts": [
|
"col_cuts": [
|
||||||
@ -13,17 +13,18 @@
|
|||||||
1
|
1
|
||||||
],
|
],
|
||||||
"best_path": [
|
"best_path": [
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
9,
|
||||||
7,
|
7,
|
||||||
8,
|
8,
|
||||||
0,
|
10,
|
||||||
6,
|
|
||||||
2,
|
2,
|
||||||
4,
|
4,
|
||||||
10,
|
6,
|
||||||
9,
|
0
|
||||||
5,
|
|
||||||
3,
|
|
||||||
1
|
|
||||||
],
|
],
|
||||||
"timestamp": "2025-04-01 17:43:22"
|
"timestamp": "2025-04-01 20:05:51"
|
||||||
}
|
}
|
@ -200,7 +200,7 @@ if __name__ == "__main__":
|
|||||||
# 需要修改的超参数
|
# 需要修改的超参数
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
params_file = 'params_50_50_3'
|
params_file = 'params_50_50_3'
|
||||||
solution_file = r'solutions\trav_ga_params_50_50_3_parallel.json'
|
solution_file = r'solutions\finetune_params_50_50_3.json'
|
||||||
|
|
||||||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||||||
params = yaml.safe_load(file)
|
params = yaml.safe_load(file)
|
||||||
|
Loading…
Reference in New Issue
Block a user