# Copyright 2017 The TensorFlow Agents Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Combine multiple environments to step them in batch.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np class BatchEnv(object): """Combine multiple environments to step them in batch.""" def __init__(self, envs, blocking): """Combine multiple environments to step them in batch. To step environments in parallel, environments must support a `blocking=False` argument to their step and reset functions that makes them return callables instead to receive the result at a later time. Args: envs: List of environments. blocking: Step environments after another rather than in parallel. Raises: ValueError: Environments have different observation or action spaces. """ self._envs = envs self._blocking = blocking observ_space = self._envs[0].observation_space if not all(env.observation_space == observ_space for env in self._envs): raise ValueError('All environments must use the same observation space.') action_space = self._envs[0].action_space if not all(env.action_space == action_space for env in self._envs): raise ValueError('All environments must use the same observation space.') def __len__(self): """Number of combined environments.""" return len(self._envs) def __getitem__(self, index): """Access an underlying environment by index.""" return self._envs[index] def __getattr__(self, name): """Forward unimplemented attributes to one of the original environments. Args: name: Attribute that was accessed. Returns: Value behind the attribute name one of the wrapped environments. """ return getattr(self._envs[0], name) def step(self, actions): """Forward a batch of actions to the wrapped environments. Args: actions: Batched action to apply to the environment. Raises: ValueError: Invalid actions. Returns: Batch of observations, rewards, and done flags. """ for index, (env, action) in enumerate(zip(self._envs, actions)): if not env.action_space.contains(action): message = 'Invalid action at index {}: {}' raise ValueError(message.format(index, action)) if self._blocking: transitions = [env.step(action) for env, action in zip(self._envs, actions)] else: transitions = [env.step(action, blocking=False) for env, action in zip(self._envs, actions)] transitions = [transition() for transition in transitions] observs, rewards, dones, infos = zip(*transitions) observ = np.stack(observs) reward = np.stack(rewards) done = np.stack(dones) info = tuple(infos) return observ, reward, done, info def reset(self, indices=None): """Reset the environment and convert the resulting observation. Args: indices: The batch indices of environments to reset; defaults to all. Returns: Batch of observations. """ if indices is None: indices = np.arange(len(self._envs)) if self._blocking: observs = [self._envs[index].reset() for index in indices] else: observs = [self._envs[index].reset(blocking=False) for index in indices] observs = [observ() for observ in observs] observ = np.stack(observs) return observ def close(self): """Send close messages to the external process and join them.""" for env in self._envs: if hasattr(env, 'close'): env.close()