41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
![]() |
import matplotlib.pyplot as plt
|
||
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
|
||
|
def plot_dqn_training_curve(csv_file, *, num_intervals=10, title=None, show=True):
    """Plot a reward-vs-step training curve read from a CSV file.

    The CSV must have a ``Step`` column (x-axis) and a ``Value`` column
    (plotted as the reward, y-axis). X-axis ticks are labelled in units of
    10,000 steps with a ``w`` suffix (万), e.g. ``50000`` -> ``5w``.

    Args:
        csv_file: Anything ``pandas.read_csv`` accepts (path or file-like).
        num_intervals: Roughly how many intervals the x-axis ticks split the
            step range into. Must be >= 1. Defaults to 10.
        title: Optional plot title; no title is drawn when ``None``.
        show: Call ``plt.show()`` after drawing when true (the default).

    Returns:
        The matplotlib ``Axes`` the curve was drawn on (the current axes).

    Raises:
        ValueError: If ``num_intervals`` < 1, if the ``Step`` or ``Value``
            column is missing, or if the CSV has no data rows.
    """
    if num_intervals < 1:
        raise ValueError("num_intervals must be >= 1")

    # Read the CSV file.
    data = pd.read_csv(csv_file)

    # Check that the required (case-sensitive) columns are present.
    if 'Step' not in data.columns or 'Value' not in data.columns:
        raise ValueError("CSV文件必须包含'Step'和'Value'列")
    # Without rows, min()/max() below are NaN and np.arange fails obscurely.
    if data.empty:
        raise ValueError("CSV文件中没有数据")

    # Extract the data.
    steps = data['Step']
    rewards = data['Value']

    # Plot the curve; its label is displayed by the legend added below.
    plt.plot(steps, rewards, label='Reward', color='blue')

    # Set title (optional) and axis labels.
    if title is not None:
        plt.title(title, fontsize=16)
    plt.xlabel("Training Steps (w)", fontsize=14)
    plt.ylabel("Reward", fontsize=14)

    # Place x-axis ticks evenly across the step range. The spacing is
    # clamped to at least 1: for a range smaller than num_intervals the
    # integer division yields 0, and np.arange raises on a zero step.
    step_min, step_max = steps.min(), steps.max()
    tick_interval = max((step_max - step_min) // num_intervals, 1)
    ticks = np.arange(step_min, step_max + tick_interval, tick_interval)
    # True division (not //) so ticks that are not multiples of 10,000 are
    # not mislabelled, e.g. 15000 -> "1.5w" rather than "1w".
    plt.xticks(ticks=ticks, labels=[f"{tick / 10000:g}w" for tick in ticks])

    # Add grid and legend.
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    # Capture the axes before show(): once a blocking window is closed,
    # gca() would create a fresh figure instead.
    ax = plt.gca()
    if show:
        plt.show()
    return ax
|
||
|
|
||
|
if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Default run to plot; replace with your own CSV path, or pass one as the
    # first command-line argument. Joining with pathlib uses the platform's
    # path separator, so the default also resolves outside Windows (a
    # backslash-separated string is a single literal filename on POSIX).
    run_name = "DQN-PartEnv_S0_ 2025-04-02 20_13"
    default_csv = Path("runs") / run_name / f"{run_name}.csv"
    csv_file = sys.argv[1] if len(sys.argv) > 1 else default_csv

    plot_dqn_training_curve(csv_file)
|