# File: bullet3/examples/pybullet/gym/pybullet_envs/stable_baselines/utils.py
# Code adapted from https://github.com/araffin/rl-baselines-zoo
# It requires stable-baselines to be installed
# Author: Antonin RAFFIN
# MIT License
import gym
import numpy as np
from gym.wrappers import TimeLimit
from stable_baselines.common.evaluation import evaluate_policy


class TimeFeatureWrapper(gym.Wrapper):
    """
    Add remaining time to observation space for fixed length episodes.
    See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13.

    :param env: (gym.Env)
    :param max_steps: (int) Max number of steps of an episode
        if it is not wrapped in a TimeLimit object.
    :param test_mode: (bool) In test mode, the time feature is constant,
        equal to 1.0. This allows checking that the agent did not overfit this feature,
        learning a deterministic pre-defined sequence of actions.
    """

    def __init__(self, env, max_steps=1000, test_mode=False):
        assert isinstance(env.observation_space, gym.spaces.Box)
        # Add a time feature to the observation
        low, high = env.observation_space.low, env.observation_space.high
        low, high = np.concatenate((low, [0])), np.concatenate((high, [1.]))
        env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

        super(TimeFeatureWrapper, self).__init__(env)

        # If the env is already wrapped in a TimeLimit, reuse its episode length
        if isinstance(env, TimeLimit):
            self._max_steps = env._max_episode_steps
        else:
            self._max_steps = max_steps
        self._current_step = 0
        self._test_mode = test_mode

    def reset(self):
        self._current_step = 0
        return self._get_obs(self.env.reset())

    def step(self, action):
        self._current_step += 1
        obs, reward, done, info = self.env.step(action)
        return self._get_obs(obs), reward, done, info

    def _get_obs(self, obs):
        """
        Concatenate the time feature to the current observation.

        :param obs: (np.ndarray)
        :return: (np.ndarray)
        """
        # Remaining time is more general than the elapsed step count
        time_feature = 1 - (self._current_step / self._max_steps)
        if self._test_mode:
            time_feature = 1.0
        # Optionally: concatenate [time_feature, time_feature ** 2]
        return np.concatenate((obs, [time_feature]))
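
# Usage sketch (illustrative, not part of the original file): "Pendulum-v0"
# is assumed here as an example of a fixed-length Box-observation env.
# Because gym.make() wraps registered envs in a TimeLimit, the wrapper picks
# up the 200-step episode length automatically:
#
#     env = TimeFeatureWrapper(gym.make("Pendulum-v0"))
#     obs = env.reset()
#     # The observation gains one extra dimension: the last entry is the
#     # remaining-time feature, starting at 1.0 and decaying linearly to 0.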


class EvalCallback(object):
    """
    Callback for evaluating an agent.

    :param eval_env: (gym.Env) The environment used for evaluation
    :param n_eval_episodes: (int) The number of episodes to test the agent
    :param eval_freq: (int) Evaluate the agent every eval_freq calls of the callback.
    :param deterministic: (bool) Whether to use deterministic actions for evaluation
    :param best_model_save_path: (str) Path to save the best model (no saving if None)
    :param verbose: (int) Verbosity level
    """

    def __init__(self, eval_env, n_eval_episodes=5, eval_freq=10000,
                 deterministic=True, best_model_save_path=None, verbose=1):
        super(EvalCallback, self).__init__()
        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.deterministic = deterministic
        self.eval_env = eval_env
        self.verbose = verbose
        self.model, self.num_timesteps = None, 0
        self.best_model_save_path = best_model_save_path
        self.n_calls = 0

    def __call__(self, locals_, globals_):
        """
        Callback passed to `model.learn()`, called at every training step.

        :param locals_: (dict) local variables of the training loop
        :param globals_: (dict) global variables of the training loop
        :return: (bool) If False, training is aborted early
        """
        self.n_calls += 1
        # The model being trained is available in the callback locals
        self.model = locals_['self']
        self.num_timesteps = self.model.num_timesteps

        if self.n_calls % self.eval_freq == 0:
            episode_rewards, _ = evaluate_policy(self.model, self.eval_env,
                                                 n_eval_episodes=self.n_eval_episodes,
                                                 deterministic=self.deterministic,
                                                 return_episode_rewards=True)
            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)

            if self.verbose > 0:
                print("Eval num_timesteps={}, "
                      "episode_reward={:.2f} +/- {:.2f}".format(self.num_timesteps,
                                                                mean_reward, std_reward))

            if mean_reward > self.best_mean_reward:
                if self.best_model_save_path is not None:
                    print("Saving best model")
                    self.model.save(self.best_model_save_path)
                self.best_mean_reward = mean_reward
        return True
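

if __name__ == "__main__":
    # Minimal end-to-end sketch wiring both utilities together, kept under a
    # __main__ guard so importing this module has no side effects. SAC and
    # "HalfCheetahBulletEnv-v0" are illustrative assumptions, not choices made
    # by this file; running it requires stable-baselines and pybullet.
    import pybullet_envs  # noqa: F401, registers the Bullet envs with gym
    from stable_baselines import SAC

    train_env = TimeFeatureWrapper(gym.make("HalfCheetahBulletEnv-v0"))
    eval_env = TimeFeatureWrapper(gym.make("HalfCheetahBulletEnv-v0"))
    callback = EvalCallback(eval_env, n_eval_episodes=5, eval_freq=10000,
                            best_model_save_path="best_model")

    model = SAC('MlpPolicy', train_env, verbose=1)
    # `callback` is invoked at every training step; every `eval_freq` calls it
    # evaluates the agent and saves the best model seen so far.
    model.learn(total_timesteps=100000, callback=callback)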