# HPCC2025/plot_dqn_training.py
# (41 lines, 1.3 KiB; last modified 2025-04-08 15:49:22 +08:00)
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def _step_ticks(step_min, step_max, n_intervals=10):
    """Return x-axis tick positions spanning ``[step_min, step_max]``.

    Ticks start at ``step_min`` and are ``(step_max - step_min) // n_intervals``
    apart; the last tick may lie past ``step_max`` by less than one interval.

    When that spacing is zero (the range is smaller than ``n_intervals`` or
    there is a single data point), ``np.arange`` would fail on a zero step,
    so the distinct endpoints are returned instead.
    """
    interval = (step_max - step_min) // n_intervals
    if interval <= 0:
        return np.unique([step_min, step_max])
    return np.arange(step_min, step_max + interval, interval)


def _format_step_label(step, unit=10000):
    """Format ``step`` as a multiple of ``unit`` followed by a ``w`` suffix.

    Up to two decimals are kept, so ticks that are not whole multiples of
    ``unit`` still get distinct labels (5000 -> "0.5w"). Trailing zeros and
    a dangling decimal point are trimmed (20000 -> "2w").
    """
    text = f"{step / unit:.2f}".rstrip("0").rstrip(".")
    return f"{text}w"


def plot_dqn_training_curve(csv_file):
    """Plot a DQN reward curve from a CSV with ``Step`` and ``Value`` columns.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.

    Raises:
        ValueError: If the CSV lacks a ``Step`` or ``Value`` column, or has
            no data rows.

    The x axis is labelled in units of 10,000 steps ("w"). The figure is
    drawn on the current pyplot figure and displayed with ``plt.show()``;
    nothing is returned.
    """
    data = pd.read_csv(csv_file)

    # Validate before touching matplotlib so bad input fails fast.
    if 'Step' not in data.columns or 'Value' not in data.columns:
        raise ValueError("CSV文件必须包含'Step'和'Value'列")
    if data.empty:
        # min()/max() below would be NaN and the tick computation meaningless.
        raise ValueError("CSV文件没有数据行")

    steps = data['Step']
    rewards = data['Value']

    plt.plot(steps, rewards, label='Reward', color='blue')
    # plt.title("DQN Training Curve", fontsize=16)
    plt.xlabel("Training Steps (w)", fontsize=14)
    plt.ylabel("Reward", fontsize=14)

    # About 10 evenly spaced ticks across the observed step range.
    ticks = _step_ticks(steps.min(), steps.max())
    plt.xticks(ticks=ticks, labels=[_format_step_label(x) for x in ticks])

    plt.grid(True, linestyle='--', alpha=0.7)
    # The line is drawn with label='Reward'; it only shows once a legend exists.
    plt.legend()

    plt.show()
if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Default CSV for one training run; pass another CSV path as the first
    # command-line argument to override it.
    # NOTE: the run name really contains a space after "S0_".
    run_name = "DQN-PartEnv_S0_ 2025-04-02 20_13"
    # Build the path with pathlib so the separator works on any OS
    # (the original hard-coded Windows backslashes).
    default_csv = Path("runs") / run_name / f"{run_name}.csv"
    csv_file = sys.argv[1] if len(sys.argv) > 1 else default_csv
    plot_dqn_training_curve(csv_file)