143 lines
4.9 KiB
Python
143 lines
4.9 KiB
Python
import os
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def save_graph(env_name='RoboschoolWalker2d-v1', fig_num=0, plot_avg=True,
               log_root="PPO_logs", figures_root="PPO_figs"):
    """Plot smoothed training rewards from PPO CSV logs and save the figure.

    Log files are read from ``<log_root>/<env_name>/`` and must be named
    ``PPO_<env_name>_log_<run>.csv`` where ``<run>`` is a non-negative
    integer. Every log is expected to contain ``timestep`` and ``reward``
    columns. The figure is written to
    ``<figures_root>/<env_name>/PPO_<env_name>_fig_<fig_num>.png`` and then
    displayed with ``plt.show()``.

    Args:
        env_name: Environment name used to locate the logs and to title the
            plot (e.g. 'CartPole-v1', 'LunarLander-v2', 'BipedalWalker-v2').
        fig_num: Suffix of the saved figure; change it to avoid overwriting
            an earlier figure in the same environment folder.
        plot_avg: If True, plot the row-wise average of all runs; otherwise
            plot every run separately, each in its own color.
        log_root: Directory that holds the per-environment log folders.
        figures_root: Directory under which figure folders are created.

    Raises:
        FileNotFoundError: If the log directory does not exist or contains
            no matching log file.
    """
    print("============================================================================================")

    fig_width = 10
    fig_height = 6

    # Two rolling means are drawn per curve: a wide window for the solid
    # "smooth" line and a narrow window for the faint, noisier "var" line.
    window_len_smooth = 20
    min_window_len_smooth = 1
    linewidth_smooth = 1.5
    alpha_smooth = 1

    window_len_var = 5
    min_window_len_var = 1
    linewidth_var = 2
    alpha_var = 0.1

    colors = ['red', 'blue', 'green', 'orange', 'purple', 'olive', 'brown',
              'magenta', 'cyan', 'crimson', 'gray', 'black']

    # Validate the inputs first so a bad path fails with a clear error and
    # without leaving empty figure directories behind.
    log_dir = os.path.join(log_root, env_name)
    if not os.path.isdir(log_dir):
        raise FileNotFoundError("log directory not found : " + log_dir)

    # Select log files by name instead of counting every file in the folder,
    # so stray files or gaps in the run numbering do not break loading.
    log_prefix = 'PPO_' + env_name + '_log_'
    log_files = []
    for f_name in os.listdir(log_dir):
        stem, ext = os.path.splitext(f_name)
        run_id = stem[len(log_prefix):]
        if ext == '.csv' and stem.startswith(log_prefix) and run_id.isdecimal():
            log_files.append((int(run_id), f_name))
    log_files.sort()

    if not log_files:
        raise FileNotFoundError(
            "no '" + log_prefix + "<run>.csv' files found in : " + log_dir)

    all_runs = []
    for run_id, f_name in log_files:
        log_f_name = os.path.join(log_dir, f_name)
        print("loading data from : " + log_f_name)
        data = pd.read_csv(log_f_name)
        print("data shape : ", data.shape)
        all_runs.append((run_id, data))
        print("--------------------------------------------------------------------------------------------")

    # make environment directory for saving figures
    figures_dir = os.path.join(figures_root, env_name)
    os.makedirs(figures_dir, exist_ok=True)
    fig_save_path = os.path.join(
        figures_dir, 'PPO_' + env_name + '_fig_' + str(fig_num) + '.png')

    # Use a fresh figure so repeated calls never draw onto earlier axes.
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    if plot_avg:
        # Average all runs row by row (rows are matched on their index).
        df_concat = pd.concat([run for _, run in all_runs])
        data_avg = df_concat.groupby(df_concat.index).mean()

        data_avg['reward_smooth'] = data_avg['reward'].rolling(
            window=window_len_smooth, win_type='triang',
            min_periods=min_window_len_smooth).mean()
        data_avg['reward_var'] = data_avg['reward'].rolling(
            window=window_len_var, win_type='triang',
            min_periods=min_window_len_var).mean()

        data_avg.plot(kind='line', x='timestep', y='reward_smooth', ax=ax,
                      color=colors[0], linewidth=linewidth_smooth,
                      alpha=alpha_smooth)
        data_avg.plot(kind='line', x='timestep', y='reward_var', ax=ax,
                      color=colors[0], linewidth=linewidth_var,
                      alpha=alpha_var)

        # keep only reward_smooth in the legend and rename it
        handles, _ = ax.get_legend_handles_labels()
        ax.legend([handles[0]],
                  ["reward_avg_" + str(len(all_runs)) + "_runs"], loc=2)
    else:
        for i, (run_id, run) in enumerate(all_runs):
            smooth_col = 'reward_smooth_' + str(run_id)
            var_col = 'reward_var_' + str(run_id)
            color = colors[i % len(colors)]  # cycle when runs outnumber colors

            run[smooth_col] = run['reward'].rolling(
                window=window_len_smooth, win_type='triang',
                min_periods=min_window_len_smooth).mean()
            run[var_col] = run['reward'].rolling(
                window=window_len_var, win_type='triang',
                min_periods=min_window_len_var).mean()

            run.plot(kind='line', x='timestep', y=smooth_col, ax=ax,
                     color=color, linewidth=linewidth_smooth,
                     alpha=alpha_smooth)
            run.plot(kind='line', x='timestep', y=var_col, ax=ax,
                     color=color, linewidth=linewidth_var, alpha=alpha_var)

        # Lines were added in (smooth, var) pairs; keep only the smooth ones
        # in the legend.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[::2], labels[::2], loc=2)

    ax.grid(color='gray', linestyle='-', linewidth=1, alpha=0.2)

    ax.set_xlabel("Timesteps", fontsize=12)
    ax.set_ylabel("Rewards", fontsize=12)
    ax.set_title(env_name, fontsize=14)

    print("============================================================================================")
    fig.savefig(fig_save_path)
    print("figure saved at : ", fig_save_path)
    print("============================================================================================")

    plt.show()
|
|
|
|
|
|
# Script entry point: build and save the reward plot only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    save_graph()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|