# HPCC2025/SAC_Continuous/main.py
# weixin_46229132 5b468deb9d SAC
# 2025-03-21 16:04:42 +08:00
#
# 156 lines
# 6.4 KiB
# Python

from utils import str2bool, evaluate_policy, Action_adapter, Action_adapter_reverse, Reward_adapter
from datetime import datetime
from SAC import SAC_countinuous
import gymnasium as gym
import os
import shutil
import argparse
import torch
# fmt: off
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import PartitionMazeEnv
# fmt: on
'''Hyperparameter Setting'''
# Command-line options. They are parsed at import time into the module-level
# namespace `opt`, which main() reads and extends (state_dim, action_dim,
# max_action, max_e_steps) before passing every field to SAC_countinuous as
# keyword arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--dvc', type=str, default='cpu',
                    help='running device: cuda or cpu')
parser.add_argument('--EnvIdex', type=int, default=0,
                    help='Env index: 0 PM_SAC_Con, 1 PV1, 2 LLdV2, 3 Humanv4, '
                         '4 HCv4, 5 BWv3, 6 BWHv3 (the env itself is always '
                         'PartitionMazeEnv)')
parser.add_argument('--write', type=str2bool, default=True,
                    help='Use SummaryWriter to record the training')
parser.add_argument('--render', type=str2bool,
                    default=False, help='Render or Not')
parser.add_argument('--Loadmodel', type=str2bool,
                    default=False, help='Load pretrained model or Not')
parser.add_argument('--ModelIdex', type=int, default=100,
                    help='which model to load')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--Max_train_steps', type=int,
                    default=int(5e8), help='Max training steps')
parser.add_argument('--save_interval', type=int,
                    default=int(5e5), help='Model saving interval, in steps.')
parser.add_argument('--eval_interval', type=int, default=int(5e3),
                    help='Model evaluating interval, in steps.')
parser.add_argument('--update_every', type=int, default=50,
                    help='Training frequency, in steps')
parser.add_argument('--gamma', type=float, default=0.99,
                    help='Discounted Factor')
parser.add_argument('--net_width', type=int, default=256,
                    help='Hidden net width, s_dim-400-300-a_dim')
parser.add_argument('--a_lr', type=float, default=3e-4,
                    help='Learning rate of actor')
parser.add_argument('--c_lr', type=float, default=3e-4,
                    help='Learning rate of critic')
parser.add_argument('--batch_size', type=int, default=256,
                    help='batch_size of training')
parser.add_argument('--alpha', type=float, default=0.12,
                    help='Entropy coefficient')
parser.add_argument('--adaptive_alpha', type=str2bool,
                    default=True, help='Use adaptive_alpha or Not')
opt = parser.parse_args()
opt.dvc = torch.device(opt.dvc)  # from str to torch.device
print(opt)
def _log_time_suffix(now):
    """Return the timestamp suffix appended to the TensorBoard log directory.

    The format is ' YYYY-MM-DD HH_MM' (leading space included), e.g.
    ' 2025-03-21 16_04'. This matches the string the previous code built by
    slicing str(datetime.now()). That slicing assumed the string always ends
    in '.ffffff', but str(datetime) omits that part when microsecond == 0,
    which garbled the name. strftime has no such special case.
    """
    return now.strftime(' %Y-%m-%d %H_%M')


def _reset_obs(env, seed):
    """Reset `env` with `seed` and return only the initial observation.

    Gymnasium's reset() returns (obs, info), and env.step() below is unpacked
    Gymnasium-style (5 values). Both a bare observation and an (obs, info)
    pair are accepted here.
    NOTE(review): confirm which form PartitionMazeEnv.reset returns.
    """
    out = env.reset(seed=seed)
    if isinstance(out, tuple) and len(out) == 2:
        return out[0]
    return out


def _train(agent, env, eval_env, writer, env_seed, brief_name):
    """Run the SAC training loop until opt.Max_train_steps env steps.

    The agent works with actions `a` in [-1, 1]. The environment receives them
    rescaled to its own range (`act`) via Action_adapter. The replay buffer
    stores `a`, so stored actions are on the same scale the actor produces.
    `writer` may be None, meaning logging is disabled.
    """
    total_steps = 0
    while total_steps < opt.Max_train_steps:
        # Do not use opt.seed directly, or it can overfit to opt.seed.
        s = _reset_obs(env, env_seed)
        env_seed += 1
        done = False

        # Interact & train.
        while not done:
            if total_steps < 5 * opt.max_e_steps:
                # Warm-up: random actions in env scale, mapped back to agent
                # scale so the buffer only ever holds agent-scale actions.
                act = env.action_space.sample()  # act ∈ [-max, max]
                a = Action_adapter_reverse(act, opt.max_action)  # a ∈ [-1, 1]
            else:
                a = agent.select_action(s, deterministic=False)  # a ∈ [-1, 1]
                act = Action_adapter(a, opt.max_action)  # act ∈ [-max, max]
            s_next, r, dw, tr, _info = env.step(act)  # dw: dead&win; tr: truncated
            r = Reward_adapter(r, opt.EnvIdex)
            done = dw or tr
            agent.replay_buffer.add(s, a, r, s_next, dw)
            s = s_next
            total_steps += 1

            # Train update_every times every update_every steps rather than
            # once per step.
            if total_steps >= 2 * opt.max_e_steps and total_steps % opt.update_every == 0:
                for _ in range(opt.update_every):
                    agent.train()

            # Record & log.
            if total_steps % opt.eval_interval == 0:
                ep_r = evaluate_policy(eval_env, agent, turns=3)
                if writer is not None:
                    writer.add_scalar('ep_r', ep_r, global_step=total_steps)
                print(
                    f'EnvName:{brief_name}, Steps: {int(total_steps/1000)}k, Episode Reward:{ep_r}')

            # Save model.
            if total_steps % opt.save_interval == 0:
                agent.save(brief_name, int(total_steps / 1000))


def main():
    """Build PartitionMazeEnv and the SAC agent, then train or evaluate.

    All settings come from the module-level `opt` (argparse namespace). This
    function adds state_dim, action_dim, max_action and max_e_steps to `opt`
    before constructing the agent from vars(opt). With --render it evaluates
    one episode at a time forever; otherwise it trains for opt.Max_train_steps
    environment steps. Envs and the SummaryWriter are closed on exit, including
    on an exception or Ctrl+C.
    """
    EnvName = ['PartitionMaze_SAC_Continuous', 'Pendulum-v1', 'LunarLanderContinuous-v2',
               'Humanoid-v4', 'HalfCheetah-v4', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3']
    BrifEnvName = ['PM_SAC_Con', 'PV1', 'LLdV2',
                   'Humanv4', 'HCv4', 'BWv3', 'BWHv3']
    brief_name = BrifEnvName[opt.EnvIdex]

    # Build Env. Both envs are always PartitionMazeEnv (the gym.make path for
    # the other EnvName entries is no longer used); opt.EnvIdex only selects
    # names and the reward adapter.
    env = PartitionMazeEnv()
    eval_env = PartitionMazeEnv()
    opt.state_dim = env.observation_space.shape[0]
    opt.action_dim = env.action_space.shape[0]
    # Only high[0] is read, so the action space is assumed to be [-max, max]
    # in every dimension (low[0] is printed below for a visual check).
    opt.max_action = float(env.action_space.high[0])
    # Used below to size the random-action and no-training warm-up phases.
    opt.max_e_steps = 100
    print(f'Env:{EnvName[opt.EnvIdex]} state_dim:{opt.state_dim} action_dim:{opt.action_dim} '
          f'max_a:{opt.max_action} min_a:{env.action_space.low[0]}')

    # Seed everything. env_seed is advanced once per episode in _train.
    env_seed = opt.seed
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Random Seed: {}".format(opt.seed))

    # Build SummaryWriter to record training curves.
    writer = None
    if opt.write:
        from torch.utils.tensorboard import SummaryWriter
        writepath = 'logs/{}'.format(brief_name) + _log_time_suffix(datetime.now())
        if os.path.exists(writepath):
            shutil.rmtree(writepath)
        writer = SummaryWriter(log_dir=writepath)

    # vars(opt): pass every argparse field to the agent as a keyword argument.
    agent = SAC_countinuous(**vars(opt))
    if opt.Loadmodel:
        agent.load(brief_name, opt.ModelIdex)

    try:
        if opt.render:
            while True:
                score = evaluate_policy(env, agent, turns=1)
                print('EnvName:', brief_name, 'score:', score)
        else:
            _train(agent, env, eval_env, writer, env_seed, brief_name)
    finally:
        env.close()
        eval_env.close()
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    main()