可视化reward
This commit is contained in:
parent
90ad3e829d
commit
d64ec83042
14
GA/ga.py
14
GA/ga.py
@ -42,14 +42,14 @@ class GA(object):
|
||||
# print("Lens:", len(result), len(result[0]))
|
||||
return result
|
||||
|
||||
def greedy_init(self, dis_mat, num_total, num_city):
|
||||
def greedy_init(self):
|
||||
start_index = 0
|
||||
result = []
|
||||
for i in range(num_total):
|
||||
rest = [x for x in range(0, num_city)]
|
||||
for i in range(self.num_total):
|
||||
rest = [x for x in range(0, self.num_city)]
|
||||
# 所有起始点都已经生成了
|
||||
if start_index >= num_city:
|
||||
start_index = np.random.randint(0, num_city)
|
||||
if start_index >= self.num_city:
|
||||
start_index = np.random.randint(0, self.num_city)
|
||||
result.append(result[start_index].copy())
|
||||
continue
|
||||
current = start_index
|
||||
@ -61,8 +61,8 @@ class GA(object):
|
||||
tmp_choose = -1
|
||||
for x in rest:
|
||||
# print("---", current, x, dis_mat[current][x])
|
||||
if dis_mat[current][x] < tmp_min:
|
||||
tmp_min = dis_mat[current][x]
|
||||
if self.dis_mat[current][x] < tmp_min:
|
||||
tmp_min = self.dis_mat[current][x]
|
||||
tmp_choose = x
|
||||
if tmp_choose == -1: # 此种情况仅可能发生在剩的都是基地点
|
||||
tmp_choose = rest[0]
|
||||
|
@ -19,11 +19,11 @@ class PartitionEnv(gym.Env):
|
||||
# 可能需要手动修改的超参数
|
||||
##############################
|
||||
self.params = 'params_100_100_6'
|
||||
self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.6, 0.8, 1]
|
||||
self.ORI_ROW_CUTS = [0, 0.28, 0.43, 0.62, 0.77, 1]
|
||||
self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1]
|
||||
self.CUT_NUM = 9
|
||||
self.BASE_LINE = 19757.42
|
||||
self.MAX_ADJUST_STEP = 80
|
||||
self.CUT_NUM = 5
|
||||
self.BASE_LINE = 19376.06
|
||||
self.MAX_ADJUST_STEP = 50
|
||||
# self.ADJUST_THRESHOLD = 0.1
|
||||
# self.mTSP_STEPS = 10000
|
||||
|
||||
@ -31,7 +31,7 @@ class PartitionEnv(gym.Env):
|
||||
self.action_space = spaces.Discrete(self.CUT_NUM*2 + 1)
|
||||
# 定义观察空间为8维向量
|
||||
self.observation_space = spaces.Box(
|
||||
low=0.0, high=1.0, shape=(self.CUT_NUM + 4,), dtype=np.float32)
|
||||
low=0.0, high=1.0, shape=(len(self.ORI_ROW_CUTS)+len(self.ORI_COL_CUTS),), dtype=np.float32)
|
||||
|
||||
self.row_cuts = self.ORI_ROW_CUTS[:]
|
||||
self.col_cuts = self.ORI_COL_CUTS[:]
|
||||
@ -94,16 +94,16 @@ class PartitionEnv(gym.Env):
|
||||
cut_index, signal = (action + 1) // 2, (action + 1) % 2
|
||||
if action == 0:
|
||||
pass
|
||||
elif cut_index <= 4:
|
||||
elif cut_index <= 5:
|
||||
if signal == 0:
|
||||
self.row_cuts[cut_index] += 0.01
|
||||
self.col_cuts[cut_index] += 0.005
|
||||
else:
|
||||
self.row_cuts[cut_index] -= 0.01
|
||||
self.col_cuts[cut_index] -= 0.005
|
||||
else:
|
||||
if signal == 0:
|
||||
self.col_cuts[cut_index-4] += 0.01
|
||||
self.col_cuts[cut_index-4] += 0.005
|
||||
else:
|
||||
self.col_cuts[cut_index-4] -= 0.01
|
||||
self.col_cuts[cut_index-4] -= 0.005
|
||||
|
||||
# 检查row_cuts和col_cuts是否按升序排列
|
||||
if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and
|
||||
|
@ -13,7 +13,7 @@ print('state:', state)
|
||||
# action_series = [1] * 30
|
||||
# action_series = [[0.2], [0.4], [0.7], [0.5]]
|
||||
# action_series = [[-0.08], [-0.08], [0], [0]]
|
||||
action_series = list(range(19))
|
||||
action_series = list(range(11))
|
||||
|
||||
for i in range(100):
|
||||
action = action_series[i]
|
||||
|
40
plot_dqn_training.py
Normal file
40
plot_dqn_training.py
Normal file
@ -0,0 +1,40 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
def plot_dqn_training_curve(csv_file):
|
||||
# 读取CSV文件
|
||||
data = pd.read_csv(csv_file)
|
||||
|
||||
# 检查是否包含所需列
|
||||
if 'Step' not in data.columns or 'Value' not in data.columns:
|
||||
raise ValueError("CSV文件必须包含'step'和'value'列")
|
||||
|
||||
# 提取数据
|
||||
steps = data['Step']
|
||||
rewards = data['Value']
|
||||
|
||||
# Plot the curve
|
||||
plt.plot(steps, rewards, label='Reward', color='blue')
|
||||
|
||||
# Set title and axis labels
|
||||
# plt.title("DQN Training Curve", fontsize=16)
|
||||
plt.xlabel("Training Steps (w)", fontsize=14)
|
||||
plt.ylabel("Reward", fontsize=14)
|
||||
|
||||
# Adjust x-axis ticks dynamically based on step range
|
||||
step_min, step_max = steps.min(), steps.max()
|
||||
tick_interval = (step_max - step_min) // 10 # Divide into 10 intervals
|
||||
ticks = np.arange(step_min, step_max + tick_interval, tick_interval)
|
||||
plt.xticks(ticks=ticks, labels=[f"{x//10000}w" for x in ticks])
|
||||
|
||||
# Add grid and legend
|
||||
plt.grid(True, linestyle='--', alpha=0.7)
|
||||
|
||||
# Show the plot
|
||||
plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 替换为你的CSV文件路径
|
||||
csv_file = r"runs\DQN-PartEnv_S0_ 2025-04-02 20_13\DQN-PartEnv_S0_ 2025-04-02 20_13.csv"
|
||||
plot_dqn_training_curve(csv_file)
|
@ -1,11 +1,11 @@
|
||||
{
|
||||
"best_time": 19643.795059416032,
|
||||
"best_time": 19376.05694186515,
|
||||
"row_cuts": [
|
||||
0,
|
||||
0.2,
|
||||
0.4,
|
||||
0.6,
|
||||
0.78,
|
||||
0.2800000000000001,
|
||||
0.43000000000000005,
|
||||
0.62,
|
||||
0.77,
|
||||
1
|
||||
],
|
||||
"col_cuts": [
|
||||
@ -56,5 +56,5 @@
|
||||
19,
|
||||
0
|
||||
],
|
||||
"timestamp": "2025-04-05 11:03:20"
|
||||
"timestamp": "2025-04-06 09:10:53"
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
{
|
||||
"row_boundaries": [
|
||||
0.0,
|
||||
0.2,
|
||||
0.4,
|
||||
0.6000000000000001,
|
||||
0.8,
|
||||
1.0
|
||||
],
|
||||
"col_boundaries": [
|
||||
0,
|
||||
0.12533333333333332,
|
||||
0.25066666666666665,
|
||||
0.376,
|
||||
0.5013333333333333,
|
||||
0.6266666666666666,
|
||||
0.7519999999999999,
|
||||
0.8773333333333332,
|
||||
1
|
||||
],
|
||||
"car_paths": [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7
|
||||
],
|
||||
[
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15
|
||||
],
|
||||
[
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23
|
||||
],
|
||||
[
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31
|
||||
],
|
||||
[
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39
|
||||
]
|
||||
]
|
||||
}
|
@ -14,9 +14,9 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r
|
||||
|
||||
# 设置英文标题和标签
|
||||
# ax.set_title("Monte Carlo", fontsize=12)
|
||||
ax.set_title("Greedy", fontsize=12)
|
||||
# ax.set_title("Greedy", fontsize=12)
|
||||
# ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
|
||||
# ax.set_title("DQN fine-tuning", fontsize=12)
|
||||
ax.set_title("DQN fine-tuning", fontsize=12)
|
||||
|
||||
ax.set_xlabel("Region Width", fontsize=10)
|
||||
ax.set_ylabel("Region Height", fontsize=10)
|
||||
@ -200,7 +200,7 @@ if __name__ == "__main__":
|
||||
# 需要修改的超参数
|
||||
# ---------------------------
|
||||
params_file = 'params_100_100_6'
|
||||
solution_file = r'solutions\greedy_params_100_100_6.json'
|
||||
solution_file = r'solutions\finetune_params_100_100_6.json'
|
||||
|
||||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||||
params = yaml.safe_load(file)
|
||||
|
Loading…
Reference in New Issue
Block a user