可视化reward

This commit is contained in:
weixin_46229132 2025-04-08 15:49:22 +08:00
parent 90ad3e829d
commit d64ec83042
7 changed files with 67 additions and 100 deletions

View File

@ -42,14 +42,14 @@ class GA(object):
# print("Lens:", len(result), len(result[0]))
return result
def greedy_init(self, dis_mat, num_total, num_city):
def greedy_init(self):
start_index = 0
result = []
for i in range(num_total):
rest = [x for x in range(0, num_city)]
for i in range(self.num_total):
rest = [x for x in range(0, self.num_city)]
# 所有起始点都已经生成了
if start_index >= num_city:
start_index = np.random.randint(0, num_city)
if start_index >= self.num_city:
start_index = np.random.randint(0, self.num_city)
result.append(result[start_index].copy())
continue
current = start_index
@ -61,8 +61,8 @@ class GA(object):
tmp_choose = -1
for x in rest:
# print("---", current, x, dis_mat[current][x])
if dis_mat[current][x] < tmp_min:
tmp_min = dis_mat[current][x]
if self.dis_mat[current][x] < tmp_min:
tmp_min = self.dis_mat[current][x]
tmp_choose = x
if tmp_choose == -1: # 此种情况仅可能发生在剩的都是基地点
tmp_choose = rest[0]

View File

@ -19,11 +19,11 @@ class PartitionEnv(gym.Env):
# 可能需要手动修改的超参数
##############################
self.params = 'params_100_100_6'
self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.6, 0.8, 1]
self.ORI_ROW_CUTS = [0, 0.28, 0.43, 0.62, 0.77, 1]
self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1]
self.CUT_NUM = 9
self.BASE_LINE = 19757.42
self.MAX_ADJUST_STEP = 80
self.CUT_NUM = 5
self.BASE_LINE = 19376.06
self.MAX_ADJUST_STEP = 50
# self.ADJUST_THRESHOLD = 0.1
# self.mTSP_STEPS = 10000
@ -31,7 +31,7 @@ class PartitionEnv(gym.Env):
self.action_space = spaces.Discrete(self.CUT_NUM*2 + 1)
# Observation space: one value in [0, 1] per row cut and per column cut (not a fixed 8-dimensional vector)
self.observation_space = spaces.Box(
low=0.0, high=1.0, shape=(self.CUT_NUM + 4,), dtype=np.float32)
low=0.0, high=1.0, shape=(len(self.ORI_ROW_CUTS)+len(self.ORI_COL_CUTS),), dtype=np.float32)
self.row_cuts = self.ORI_ROW_CUTS[:]
self.col_cuts = self.ORI_COL_CUTS[:]
@ -94,16 +94,16 @@ class PartitionEnv(gym.Env):
cut_index, signal = (action + 1) // 2, (action + 1) % 2
if action == 0:
pass
elif cut_index <= 4:
elif cut_index <= 5:
if signal == 0:
self.row_cuts[cut_index] += 0.01
self.col_cuts[cut_index] += 0.005
else:
self.row_cuts[cut_index] -= 0.01
self.col_cuts[cut_index] -= 0.005
else:
if signal == 0:
self.col_cuts[cut_index-4] += 0.01
self.col_cuts[cut_index-4] += 0.005
else:
self.col_cuts[cut_index-4] -= 0.01
self.col_cuts[cut_index-4] -= 0.005
# 检查row_cuts和col_cuts是否按升序排列
if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and

View File

@ -13,7 +13,7 @@ print('state:', state)
# action_series = [1] * 30
# action_series = [[0.2], [0.4], [0.7], [0.5]]
# action_series = [[-0.08], [-0.08], [0], [0]]
action_series = list(range(19))
action_series = list(range(11))
for i in range(100):
action = action_series[i]

40
plot_dqn_training.py Normal file
View File

@ -0,0 +1,40 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def plot_dqn_training_curve(csv_file):
    """Plot the reward curve of a DQN training run from an exported CSV file.

    Args:
        csv_file: Path to a CSV file that has a 'Step' column (x-axis) and
            a 'Value' column (plotted as the reward).

    Raises:
        ValueError: If the 'Step' or 'Value' column is missing, or if the
            file contains no data rows.
    """
    data = pd.read_csv(csv_file)

    # Column names are case-sensitive; the message names them exactly as
    # they are checked so the user knows what to fix.
    if 'Step' not in data.columns or 'Value' not in data.columns:
        raise ValueError("CSV文件必须包含'Step'和'Value'列")
    # Fail early on an empty file: min()/max() would be NaN and the tick
    # computation below would fail with an unrelated error.
    if data.empty:
        raise ValueError("CSV文件不包含任何数据行")

    steps = data['Step']
    rewards = data['Value']

    # Plot the curve
    plt.plot(steps, rewards, label='Reward', color='blue')

    # Set axis labels (the title is intentionally left disabled)
    # plt.title("DQN Training Curve", fontsize=16)
    plt.xlabel("Training Steps (w)", fontsize=14)
    plt.ylabel("Reward", fontsize=14)

    # Adjust x-axis ticks dynamically: split the step range into about 10
    # intervals. A range shorter than 10 steps floors to 0, and np.arange
    # rejects a zero step, so clamp the interval to at least 1.
    step_min, step_max = steps.min(), steps.max()
    tick_interval = max((step_max - step_min) // 10, 1)
    ticks = np.arange(step_min, step_max + tick_interval, tick_interval)
    # Labels show the step count divided by 10000 with a "w" suffix
    # (presumably "wan", i.e. units of ten thousand steps).
    plt.xticks(ticks=ticks, labels=[f"{x//10000}w" for x in ticks])

    # Add grid
    plt.grid(True, linestyle='--', alpha=0.7)

    # Show the plot
    plt.show()
if __name__ == "__main__":
    # Replace with the path to your own exported CSV file.
    # NOTE(review): the path uses Windows-style backslashes, so it will
    # presumably not resolve on POSIX systems — confirm the target platform.
    csv_file = r"runs\DQN-PartEnv_S0_ 2025-04-02 20_13\DQN-PartEnv_S0_ 2025-04-02 20_13.csv"
    plot_dqn_training_curve(csv_file)

View File

@ -1,11 +1,11 @@
{
"best_time": 19643.795059416032,
"best_time": 19376.05694186515,
"row_cuts": [
0,
0.2,
0.4,
0.6,
0.78,
0.2800000000000001,
0.43000000000000005,
0.62,
0.77,
1
],
"col_cuts": [
@ -56,5 +56,5 @@
19,
0
],
"timestamp": "2025-04-05 11:03:20"
"timestamp": "2025-04-06 09:10:53"
}

View File

@ -1,73 +0,0 @@
{
"row_boundaries": [
0.0,
0.2,
0.4,
0.6000000000000001,
0.8,
1.0
],
"col_boundaries": [
0,
0.12533333333333332,
0.25066666666666665,
0.376,
0.5013333333333333,
0.6266666666666666,
0.7519999999999999,
0.8773333333333332,
1
],
"car_paths": [
[
0,
1,
2,
3,
4,
5,
6,
7
],
[
8,
9,
10,
11,
12,
13,
14,
15
],
[
16,
17,
18,
19,
20,
21,
22,
23
],
[
24,
25,
26,
27,
28,
29,
30,
31
],
[
32,
33,
34,
35,
36,
37,
38,
39
]
]
}

View File

@ -14,9 +14,9 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r
# 设置英文标题和标签
# ax.set_title("Monte Carlo", fontsize=12)
ax.set_title("Greedy", fontsize=12)
# ax.set_title("Greedy", fontsize=12)
# ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
# ax.set_title("DQN fine-tuning", fontsize=12)
ax.set_title("DQN fine-tuning", fontsize=12)
ax.set_xlabel("Region Width", fontsize=10)
ax.set_ylabel("Region Height", fontsize=10)
@ -200,7 +200,7 @@ if __name__ == "__main__":
# 需要修改的超参数
# ---------------------------
params_file = 'params_100_100_6'
solution_file = r'solutions\greedy_params_100_100_6.json'
solution_file = r'solutions\finetune_params_100_100_6.json'
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
params = yaml.safe_load(file)