# Code adapted from https://github.com/araffin/rl-baselines-zoo
# it requires stable-baselines to be installed
# Author: Antonin RAFFIN
# MIT License
import gym
import numpy as np
from gym.wrappers import TimeLimit

from stable_baselines.common.evaluation import evaluate_policy


class TimeFeatureWrapper(gym.Wrapper):
    """
    Add remaining time to the observation space for fixed-length episodes.
    See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.

    :param env: (gym.Env)
    :param max_steps: (int) Max number of steps of an episode
        if it is not wrapped in a TimeLimit object.
    :param test_mode: (bool) In test mode, the time feature is constant,
        equal to 1.0. This allows checking that the agent did not overfit this
        feature by learning a deterministic, pre-defined sequence of actions.
    """
    def __init__(self, env, max_steps=1000, test_mode=False):
        assert isinstance(env.observation_space, gym.spaces.Box)
        # Add a time feature to the observation
        low, high = env.observation_space.low, env.observation_space.high
        low, high = np.concatenate((low, [0.])), np.concatenate((high, [1.]))
        env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        super(TimeFeatureWrapper, self).__init__(env)

        # Prefer the horizon of an existing TimeLimit wrapper over max_steps
        if isinstance(env, TimeLimit):
            self._max_steps = env._max_episode_steps
        else:
            self._max_steps = max_steps
        self._current_step = 0
        self._test_mode = test_mode

    def reset(self):
        self._current_step = 0
        return self._get_obs(self.env.reset())

    def step(self, action):
        self._current_step += 1
        obs, reward, done, info = self.env.step(action)
        return self._get_obs(obs), reward, done, info

    def _get_obs(self, obs):
        """
        Concatenate the time feature to the current observation.

        :param obs: (np.ndarray)
        :return: (np.ndarray)
        """
        # Remaining time is more general
        time_feature = 1 - (self._current_step / self._max_steps)
        if self._test_mode:
            time_feature = 1.0
        # Optionally: concatenate [time_feature, time_feature ** 2]
        return np.concatenate((obs, [time_feature]))

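
# A minimal usage sketch for the wrapper (illustrative, not part of the original
# script); it assumes a Box-observation environment such as Pendulum-v0:
#
#     env = TimeFeatureWrapper(gym.make("Pendulum-v0"))
#     obs = env.reset()  # obs has one extra dimension: the remaining-time feature
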
class EvalCallback(object):
    """
    Callback for evaluating an agent.

    :param eval_env: (gym.Env) The environment used for evaluation
    :param n_eval_episodes: (int) The number of episodes to test the agent
    :param eval_freq: (int) Evaluate the agent every eval_freq calls of the callback.
    :param deterministic: (bool) Whether to use deterministic actions during evaluation
    :param best_model_save_path: (str) Path where the best model is saved (no saving if None)
    :param verbose: (int) Verbosity level (0: silent, 1: print evaluation results)
    """
    def __init__(self, eval_env, n_eval_episodes=5, eval_freq=10000,
                 deterministic=True, best_model_save_path=None, verbose=1):
        super(EvalCallback, self).__init__()
        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.deterministic = deterministic
        self.eval_env = eval_env
        self.verbose = verbose
        self.model, self.num_timesteps = None, 0
        self.best_model_save_path = best_model_save_path
        self.n_calls = 0

    def __call__(self, locals_, globals_):
        """
        :param locals_: (dict) local variables of the training loop
        :param globals_: (dict) global variables of the training loop
        :return: (bool) Whether training should continue (returning False aborts training early)
        """
        self.n_calls += 1
        self.model = locals_['self']
        self.num_timesteps = self.model.num_timesteps

        if self.n_calls % self.eval_freq == 0:
            episode_rewards, _ = evaluate_policy(self.model, self.eval_env,
                                                 n_eval_episodes=self.n_eval_episodes,
                                                 deterministic=self.deterministic,
                                                 return_episode_rewards=True)

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            if self.verbose > 0:
                print("Eval num_timesteps={}, "
                      "episode_reward={:.2f} +/- {:.2f}".format(self.num_timesteps, mean_reward, std_reward))

            if mean_reward > self.best_mean_reward:
                if self.best_model_save_path is not None:
                    print("Saving best model")
                    self.model.save(self.best_model_save_path)
                self.best_mean_reward = mean_reward

        return True
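

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative, not part of the original script).
    # It assumes stable-baselines' SAC and the Pendulum-v0 environment are
    # available; any continuous-control Box-observation env would work.
    from stable_baselines import SAC

    # Training env and a separate evaluation env, both with the time feature
    env = TimeFeatureWrapper(gym.make("Pendulum-v0"))
    eval_env = TimeFeatureWrapper(gym.make("Pendulum-v0"))

    # Evaluate every 1000 callback calls and keep the best model on disk
    callback = EvalCallback(eval_env, n_eval_episodes=5, eval_freq=1000,
                            best_model_save_path="best_model")

    model = SAC("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, callback=callback)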