# These tags allow extracting portions of this script on Anyscale.
# ws-template-imports-start
import gymnasium as gym

from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module.frame_stacking import FrameStackingEnvToModule
from ray.rllib.connectors.learner.frame_stacking import FrameStackingLearner
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
from ray.rllib.utils.test_utils import add_rllib_example_script_args

# ws-template-imports-end

parser = add_rllib_example_script_args(
    default_reward=float("inf"),
    default_timesteps=3000000,
    default_iters=100000000000,
)
parser.set_defaults(
    enable_new_api_stack=True,
    env="ale_py:ALE/Pong-v5",
)
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()
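
# Example invocation (flag names follow current RLlib example scripts and may
# vary by version; run with `--help` to see the exact options the parser adds):
#   python <this-script>.py --env=ale_py:ALE/Breakout-v5 --num-learners=2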

NUM_LEARNERS = args.num_learners or 1
ENV = args.env


# These tags allow extracting portions of this script on Anyscale.
# ws-template-code-start
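# The two factories below wire frame stacking into RLlib's ConnectorV2
# pipelines instead of into the env itself: the env-to-module connector stacks
# the most recent 4 frames for action computation on the EnvRunners, and the
# learner connector performs the equivalent stacking on the train batch. This
# way, only single (unstacked) frames need to be stored and shipped around.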
def _make_env_to_module_connector(env):
    return FrameStackingEnvToModule(num_frames=4)


def _make_learner_connector(input_observation_space, input_action_space):
    return FrameStackingLearner(num_frames=4)


# Create a custom Atari setup (w/o the usual RLlib-hard-coded framestacking
# in it). We would like our frame-stacking connectors to do this job instead.
def _env_creator(cfg):
    return wrap_atari_for_new_api_stack(
        gym.make(ENV, **cfg, render_mode="rgb_array"),
        # Perform frame-stacking through the ConnectorV2 API.
        framestack=None,
    )


tune.register_env("env", _env_creator)
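
# `wrap_atari_for_new_api_stack` applies the usual Atari preprocessing (e.g.
# grayscaling and resizing); `framestack=None` leaves the stacking to the two
# connectors above. The name "env" registered here is what the
# `.environment()` call below refers to.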

config = (
    PPOConfig()
    .environment(
        "env",
        env_config={
            # Make analogous to the old "v4 + NoFrameskip" setup: turn off
            # ALE's built-in frameskip and sticky actions (v5 defaults to a
            # `repeat_action_probability` of 0.25).
            "frameskip": 1,
            "full_action_space": False,
            "repeat_action_probability": 0.0,
        },
        clip_rewards=True,
    )
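    # Plug the frame-stacking connectors into the EnvRunners (env-to-module)
    # and, further below, into the Learners (learner connector).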
    .env_runners(
        env_to_module_connector=_make_env_to_module_connector,
    )
    .training(
        learner_connector=_make_learner_connector,
        train_batch_size_per_learner=4000,
        minibatch_size=128,
        lambda_=0.95,
        kl_coeff=0.5,
        clip_param=0.1,
        vf_clip_param=10.0,
        entropy_coeff=0.01,
        num_epochs=10,
        # Scale the learning rate linearly with the number of Learner workers.
        lr=0.00015 * NUM_LEARNERS,
        grad_clip=100.0,
        grad_clip_by="global_norm",
    )
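    # Configure the default RLModule: each `conv_filters` entry is
    # [num output channels, kernel size, stride].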
    .rl_module(
        model_config=DefaultModelConfig(
            conv_filters=[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]],
            conv_activation="relu",
            head_fcnet_hiddens=[256],
            vf_share_layers=True,
        ),
    )
)
# ws-template-code-end
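
# Running this script directly hands off to RLlib's example-script runner,
# which starts a Ray Tune experiment using the stopping criteria (reward,
# timesteps, iterations) configured via the command-line args above.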
if __name__ == "__main__":
    from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

    run_rllib_example_script_experiment(config, args=args)