Files
bullet3/examples/pybullet/gym/pybullet_envs/agents/tools/batch_env.py
Erwin Coumans ef9570c315 add yapf style and apply yapf to format all Python files
This recreates pull request #2192
2019-04-27 07:31:15 -07:00

120 lines
4.0 KiB
Python

# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Combine multiple environments to step them in batch."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
class BatchEnv(object):
"""Combine multiple environments to step them in batch."""
def __init__(self, envs, blocking):
"""Combine multiple environments to step them in batch.
To step environments in parallel, environments must support a
`blocking=False` argument to their step and reset functions that makes them
return callables instead to receive the result at a later time.
Args:
envs: List of environments.
blocking: Step environments after another rather than in parallel.
Raises:
ValueError: Environments have different observation or action spaces.
"""
self._envs = envs
self._blocking = blocking
observ_space = self._envs[0].observation_space
if not all(env.observation_space == observ_space for env in self._envs):
raise ValueError('All environments must use the same observation space.')
action_space = self._envs[0].action_space
if not all(env.action_space == action_space for env in self._envs):
raise ValueError('All environments must use the same observation space.')
def __len__(self):
"""Number of combined environments."""
return len(self._envs)
def __getitem__(self, index):
"""Access an underlying environment by index."""
return self._envs[index]
def __getattr__(self, name):
"""Forward unimplemented attributes to one of the original environments.
Args:
name: Attribute that was accessed.
Returns:
Value behind the attribute name one of the wrapped environments.
"""
return getattr(self._envs[0], name)
def step(self, actions):
"""Forward a batch of actions to the wrapped environments.
Args:
actions: Batched action to apply to the environment.
Raises:
ValueError: Invalid actions.
Returns:
Batch of observations, rewards, and done flags.
"""
for index, (env, action) in enumerate(zip(self._envs, actions)):
if not env.action_space.contains(action):
message = 'Invalid action at index {}: {}'
raise ValueError(message.format(index, action))
if self._blocking:
transitions = [env.step(action) for env, action in zip(self._envs, actions)]
else:
transitions = [env.step(action, blocking=False) for env, action in zip(self._envs, actions)]
transitions = [transition() for transition in transitions]
observs, rewards, dones, infos = zip(*transitions)
observ = np.stack(observs)
reward = np.stack(rewards)
done = np.stack(dones)
info = tuple(infos)
return observ, reward, done, info
def reset(self, indices=None):
"""Reset the environment and convert the resulting observation.
Args:
indices: The batch indices of environments to reset; defaults to all.
Returns:
Batch of observations.
"""
if indices is None:
indices = np.arange(len(self._envs))
if self._blocking:
observs = [self._envs[index].reset() for index in indices]
else:
observs = [self._envs[index].reset(blocking=False) for index in indices]
observs = [observ() for observ in observs]
observ = np.stack(observs)
return observ
def close(self):
"""Send close messages to the external process and join them."""
for env in self._envs:
if hasattr(env, 'close'):
env.close()