import matplotlib.pyplot as plt import pandas as pd import numpy as np def plot_dqn_training_curve(csv_file): # 读取CSV文件 data = pd.read_csv(csv_file) # 检查是否包含所需列 if 'Step' not in data.columns or 'Value' not in data.columns: raise ValueError("CSV文件必须包含'step'和'value'列") # 提取数据 steps = data['Step'] rewards = data['Value'] # Plot the curve plt.plot(steps, rewards, label='Reward', color='blue') # Set title and axis labels # plt.title("DQN Training Curve", fontsize=16) plt.xlabel("Training Steps (w)", fontsize=14) plt.ylabel("Reward", fontsize=14) # Adjust x-axis ticks dynamically based on step range step_min, step_max = steps.min(), steps.max() tick_interval = (step_max - step_min) // 10 # Divide into 10 intervals ticks = np.arange(step_min, step_max + tick_interval, tick_interval) plt.xticks(ticks=ticks, labels=[f"{x//10000}w" for x in ticks]) # Add grid and legend plt.grid(True, linestyle='--', alpha=0.7) # Show the plot plt.show() if __name__ == "__main__": # 替换为你的CSV文件路径 csv_file = r"runs\DQN-PartEnv_S0_ 2025-04-02 20_13\DQN-PartEnv_S0_ 2025-04-02 20_13.csv" plot_dqn_training_curve(csv_file)