Add Stable-Baselines example with SAC and TD3
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
# Code adapted from https://github.com/araffin/rl-baselines-zoo
|
||||
# it requires stable-baselines to be installed
|
||||
# Colab Notebook:
|
||||
# Author: Antonin RAFFIN
|
||||
# MIT License
|
||||
|
||||
# Add parent dir to find package. Only needed for source code build, pip install doesn't need it.
|
||||
import os
|
||||
import inspect
|
||||
import argparse
|
||||
|
||||
# Make `pybullet_envs` importable. A pip install puts it on sys.path already;
# for a source-code build we fall back to the repository tree.
try:
    import pybullet_envs
except ImportError:
    # Source build: this file lives two levels below the repository root,
    # so add that root to sys.path...
    current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    parent_dir = os.path.dirname(os.path.dirname(current_dir))
    os.sys.path.insert(0, parent_dir)
    # ...and retry the import: without this, the Bullet environments were
    # never registered with gym and the later gym.make() call would fail.
    import pybullet_envs  # noqa: F811
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from stable_baselines import SAC, TD3
|
||||
from stable_baselines.common.noise import NormalActionNoise
|
||||
|
||||
from utils import TimeFeatureWrapper, EvalCallback
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line interface: algorithm, environment ID and training budget.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--algo', help='RL Algorithm (Soft Actor-Critic by default)', default='sac',
                            type=str, required=False, choices=['sac', 'td3'])
    arg_parser.add_argument('--env', type=str, default='HalfCheetahBulletEnv-v0', help='environment ID')
    arg_parser.add_argument('-n', '--n-timesteps', help='Number of training timesteps', default=int(1e6),
                            type=int)
    cli_args = arg_parser.parse_args()

    environment_id = cli_args.env
    total_timesteps = cli_args.n_timesteps
    save_path = '{}_{}'.format(cli_args.algo, environment_id)

    # Training environment, wrapped by TimeFeatureWrapper (adds the remaining
    # episode time to the observation).
    train_env = TimeFeatureWrapper(gym.make(environment_id))

    # A separate, identically-wrapped environment is used for periodic
    # evaluation; the best model seen so far is saved to `<save_path>_best`.
    evaluation_env = TimeFeatureWrapper(gym.make(environment_id))
    eval_callback = EvalCallback(evaluation_env, best_model_save_path=save_path + '_best')

    # Dimension of the (continuous) action space, needed for the TD3 noise.
    action_dim = train_env.action_space.shape[0]

    # Tuned hyperparameters from https://github.com/araffin/rl-baselines-zoo
    # (both entries are built eagerly; only the selected one is used).
    hyperparams_by_algo = {
        'sac': dict(batch_size=256, gamma=0.98, policy_kwargs=dict(layers=[256, 256]),
                    learning_starts=10000, buffer_size=int(1e6), tau=0.01),
        'td3': dict(batch_size=100, policy_kwargs=dict(layers=[400, 300]),
                    learning_rate=1e-3, learning_starts=10000, buffer_size=int(1e6),
                    train_freq=1000, gradient_steps=1000,
                    action_noise=NormalActionNoise(mean=np.zeros(action_dim),
                                                   sigma=0.1 * np.ones(action_dim)))
    }
    hyperparams = hyperparams_by_algo[cli_args.algo]
    algo_cls = SAC if cli_args.algo == 'sac' else TD3

    # Train, then persist the final model next to the best-model checkpoint.
    model = algo_cls('MlpPolicy', train_env, verbose=1, **hyperparams)
    model.learn(total_timesteps, callback=eval_callback)

    print("Saving to {}.zip".format(save_path))
    model.save(save_path)
|
||||
Reference in New Issue
Block a user