可视化reward
This commit is contained in:
parent
90ad3e829d
commit
d64ec83042
14
GA/ga.py
14
GA/ga.py
@ -42,14 +42,14 @@ class GA(object):
|
|||||||
# print("Lens:", len(result), len(result[0]))
|
# print("Lens:", len(result), len(result[0]))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def greedy_init(self, dis_mat, num_total, num_city):
|
def greedy_init(self):
|
||||||
start_index = 0
|
start_index = 0
|
||||||
result = []
|
result = []
|
||||||
for i in range(num_total):
|
for i in range(self.num_total):
|
||||||
rest = [x for x in range(0, num_city)]
|
rest = [x for x in range(0, self.num_city)]
|
||||||
# 所有起始点都已经生成了
|
# 所有起始点都已经生成了
|
||||||
if start_index >= num_city:
|
if start_index >= self.num_city:
|
||||||
start_index = np.random.randint(0, num_city)
|
start_index = np.random.randint(0, self.num_city)
|
||||||
result.append(result[start_index].copy())
|
result.append(result[start_index].copy())
|
||||||
continue
|
continue
|
||||||
current = start_index
|
current = start_index
|
||||||
@ -61,8 +61,8 @@ class GA(object):
|
|||||||
tmp_choose = -1
|
tmp_choose = -1
|
||||||
for x in rest:
|
for x in rest:
|
||||||
# print("---", current, x, dis_mat[current][x])
|
# print("---", current, x, dis_mat[current][x])
|
||||||
if dis_mat[current][x] < tmp_min:
|
if self.dis_mat[current][x] < tmp_min:
|
||||||
tmp_min = dis_mat[current][x]
|
tmp_min = self.dis_mat[current][x]
|
||||||
tmp_choose = x
|
tmp_choose = x
|
||||||
if tmp_choose == -1: # 此种情况仅可能发生在剩的都是基地点
|
if tmp_choose == -1: # 此种情况仅可能发生在剩的都是基地点
|
||||||
tmp_choose = rest[0]
|
tmp_choose = rest[0]
|
||||||
|
@ -19,11 +19,11 @@ class PartitionEnv(gym.Env):
|
|||||||
# 可能需要手动修改的超参数
|
# 可能需要手动修改的超参数
|
||||||
##############################
|
##############################
|
||||||
self.params = 'params_100_100_6'
|
self.params = 'params_100_100_6'
|
||||||
self.ORI_ROW_CUTS = [0, 0.2, 0.4, 0.6, 0.8, 1]
|
self.ORI_ROW_CUTS = [0, 0.28, 0.43, 0.62, 0.77, 1]
|
||||||
self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1]
|
self.ORI_COL_CUTS = [0, 0.2, 0.4, 0.5, 0.7, 0.8, 1]
|
||||||
self.CUT_NUM = 9
|
self.CUT_NUM = 5
|
||||||
self.BASE_LINE = 19757.42
|
self.BASE_LINE = 19376.06
|
||||||
self.MAX_ADJUST_STEP = 80
|
self.MAX_ADJUST_STEP = 50
|
||||||
# self.ADJUST_THRESHOLD = 0.1
|
# self.ADJUST_THRESHOLD = 0.1
|
||||||
# self.mTSP_STEPS = 10000
|
# self.mTSP_STEPS = 10000
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class PartitionEnv(gym.Env):
|
|||||||
self.action_space = spaces.Discrete(self.CUT_NUM*2 + 1)
|
self.action_space = spaces.Discrete(self.CUT_NUM*2 + 1)
|
||||||
# 定义观察空间为8维向量
|
# 定义观察空间为8维向量
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=0.0, high=1.0, shape=(self.CUT_NUM + 4,), dtype=np.float32)
|
low=0.0, high=1.0, shape=(len(self.ORI_ROW_CUTS)+len(self.ORI_COL_CUTS),), dtype=np.float32)
|
||||||
|
|
||||||
self.row_cuts = self.ORI_ROW_CUTS[:]
|
self.row_cuts = self.ORI_ROW_CUTS[:]
|
||||||
self.col_cuts = self.ORI_COL_CUTS[:]
|
self.col_cuts = self.ORI_COL_CUTS[:]
|
||||||
@ -94,16 +94,16 @@ class PartitionEnv(gym.Env):
|
|||||||
cut_index, signal = (action + 1) // 2, (action + 1) % 2
|
cut_index, signal = (action + 1) // 2, (action + 1) % 2
|
||||||
if action == 0:
|
if action == 0:
|
||||||
pass
|
pass
|
||||||
elif cut_index <= 4:
|
elif cut_index <= 5:
|
||||||
if signal == 0:
|
if signal == 0:
|
||||||
self.row_cuts[cut_index] += 0.01
|
self.col_cuts[cut_index] += 0.005
|
||||||
else:
|
else:
|
||||||
self.row_cuts[cut_index] -= 0.01
|
self.col_cuts[cut_index] -= 0.005
|
||||||
else:
|
else:
|
||||||
if signal == 0:
|
if signal == 0:
|
||||||
self.col_cuts[cut_index-4] += 0.01
|
self.col_cuts[cut_index-4] += 0.005
|
||||||
else:
|
else:
|
||||||
self.col_cuts[cut_index-4] -= 0.01
|
self.col_cuts[cut_index-4] -= 0.005
|
||||||
|
|
||||||
# 检查row_cuts和col_cuts是否按升序排列
|
# 检查row_cuts和col_cuts是否按升序排列
|
||||||
if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and
|
if (all(self.row_cuts[i] < self.row_cuts[i+1] for i in range(len(self.row_cuts)-1)) and
|
||||||
|
@ -13,7 +13,7 @@ print('state:', state)
|
|||||||
# action_series = [1] * 30
|
# action_series = [1] * 30
|
||||||
# action_series = [[0.2], [0.4], [0.7], [0.5]]
|
# action_series = [[0.2], [0.4], [0.7], [0.5]]
|
||||||
# action_series = [[-0.08], [-0.08], [0], [0]]
|
# action_series = [[-0.08], [-0.08], [0], [0]]
|
||||||
action_series = list(range(19))
|
action_series = list(range(11))
|
||||||
|
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
action = action_series[i]
|
action = action_series[i]
|
||||||
|
40
plot_dqn_training.py
Normal file
40
plot_dqn_training.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def plot_dqn_training_curve(csv_file):
    """Plot the reward curve of a DQN training run from a CSV log.

    Args:
        csv_file: Path (or any other input accepted by ``pandas.read_csv``)
            of a CSV file that has at least the columns ``Step`` and
            ``Value``.

    Raises:
        ValueError: If the ``Step`` or the ``Value`` column is missing.
    """
    # Read the CSV file
    data = pd.read_csv(csv_file)

    # Check for the required columns before plotting anything. The lookup is
    # case-sensitive, so the error message spells the names exactly as they
    # are checked.
    if 'Step' not in data.columns or 'Value' not in data.columns:
        raise ValueError("CSV文件必须包含'Step'和'Value'列")

    # Extract the data series
    steps = data['Step']
    rewards = data['Value']

    # Plot the curve
    plt.plot(steps, rewards, label='Reward', color='blue')

    # Set axis labels (the title is currently disabled)
    # plt.title("DQN Training Curve", fontsize=16)
    plt.xlabel("Training Steps (w)", fontsize=14)
    plt.ylabel("Reward", fontsize=14)

    # Adjust x-axis ticks dynamically based on the step range: aim for ten
    # intervals, but never let the interval reach 0 (a range shorter than 10
    # steps would otherwise make np.arange fail with a zero step).
    step_min, step_max = steps.min(), steps.max()
    tick_interval = max((step_max - step_min) // 10, 1)
    ticks = np.arange(step_min, step_max + tick_interval, tick_interval)
    # Tick labels express the step count in units of 10,000 ("w")
    plt.xticks(ticks=ticks, labels=[f"{x//10000}w" for x in ticks])

    # Add grid
    plt.grid(True, linestyle='--', alpha=0.7)

    # Show the plot
    plt.show()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Replace the path below with the location of your own CSV file
    plot_dqn_training_curve(
        r"runs\DQN-PartEnv_S0_ 2025-04-02 20_13\DQN-PartEnv_S0_ 2025-04-02 20_13.csv"
    )
|
@ -1,11 +1,11 @@
|
|||||||
{
|
{
|
||||||
"best_time": 19643.795059416032,
|
"best_time": 19376.05694186515,
|
||||||
"row_cuts": [
|
"row_cuts": [
|
||||||
0,
|
0,
|
||||||
0.2,
|
0.2800000000000001,
|
||||||
0.4,
|
0.43000000000000005,
|
||||||
0.6,
|
0.62,
|
||||||
0.78,
|
0.77,
|
||||||
1
|
1
|
||||||
],
|
],
|
||||||
"col_cuts": [
|
"col_cuts": [
|
||||||
@ -56,5 +56,5 @@
|
|||||||
19,
|
19,
|
||||||
0
|
0
|
||||||
],
|
],
|
||||||
"timestamp": "2025-04-05 11:03:20"
|
"timestamp": "2025-04-06 09:10:53"
|
||||||
}
|
}
|
@ -1,73 +0,0 @@
|
|||||||
{
|
|
||||||
"row_boundaries": [
|
|
||||||
0.0,
|
|
||||||
0.2,
|
|
||||||
0.4,
|
|
||||||
0.6000000000000001,
|
|
||||||
0.8,
|
|
||||||
1.0
|
|
||||||
],
|
|
||||||
"col_boundaries": [
|
|
||||||
0,
|
|
||||||
0.12533333333333332,
|
|
||||||
0.25066666666666665,
|
|
||||||
0.376,
|
|
||||||
0.5013333333333333,
|
|
||||||
0.6266666666666666,
|
|
||||||
0.7519999999999999,
|
|
||||||
0.8773333333333332,
|
|
||||||
1
|
|
||||||
],
|
|
||||||
"car_paths": [
|
|
||||||
[
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
4,
|
|
||||||
5,
|
|
||||||
6,
|
|
||||||
7
|
|
||||||
],
|
|
||||||
[
|
|
||||||
8,
|
|
||||||
9,
|
|
||||||
10,
|
|
||||||
11,
|
|
||||||
12,
|
|
||||||
13,
|
|
||||||
14,
|
|
||||||
15
|
|
||||||
],
|
|
||||||
[
|
|
||||||
16,
|
|
||||||
17,
|
|
||||||
18,
|
|
||||||
19,
|
|
||||||
20,
|
|
||||||
21,
|
|
||||||
22,
|
|
||||||
23
|
|
||||||
],
|
|
||||||
[
|
|
||||||
24,
|
|
||||||
25,
|
|
||||||
26,
|
|
||||||
27,
|
|
||||||
28,
|
|
||||||
29,
|
|
||||||
30,
|
|
||||||
31
|
|
||||||
],
|
|
||||||
[
|
|
||||||
32,
|
|
||||||
33,
|
|
||||||
34,
|
|
||||||
35,
|
|
||||||
36,
|
|
||||||
37,
|
|
||||||
38,
|
|
||||||
39
|
|
||||||
]
|
|
||||||
]
|
|
||||||
}
|
|
@ -14,9 +14,9 @@ def visualize_solution(row_boundaries, col_boundaries, car_paths_coords, W, H, r
|
|||||||
|
|
||||||
# 设置英文标题和标签
|
# 设置英文标题和标签
|
||||||
# ax.set_title("Monte Carlo", fontsize=12)
|
# ax.set_title("Monte Carlo", fontsize=12)
|
||||||
ax.set_title("Greedy", fontsize=12)
|
# ax.set_title("Greedy", fontsize=12)
|
||||||
# ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
|
# ax.set_title("Enumeration-Genetic Algorithm", fontsize=12)
|
||||||
# ax.set_title("DQN fine-tuning", fontsize=12)
|
ax.set_title("DQN fine-tuning", fontsize=12)
|
||||||
|
|
||||||
ax.set_xlabel("Region Width", fontsize=10)
|
ax.set_xlabel("Region Width", fontsize=10)
|
||||||
ax.set_ylabel("Region Height", fontsize=10)
|
ax.set_ylabel("Region Height", fontsize=10)
|
||||||
@ -200,7 +200,7 @@ if __name__ == "__main__":
|
|||||||
# 需要修改的超参数
|
# 需要修改的超参数
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
params_file = 'params_100_100_6'
|
params_file = 'params_100_100_6'
|
||||||
solution_file = r'solutions\greedy_params_100_100_6.json'
|
solution_file = r'solutions\finetune_params_100_100_6.json'
|
||||||
|
|
||||||
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
with open(params_file + '.yml', 'r', encoding='utf-8') as file:
|
||||||
params = yaml.safe_load(file)
|
params = yaml.safe_load(file)
|
||||||
|
Loading…
Reference in New Issue
Block a user