# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Script to train a batch reinforcement learning algorithm.

Command line:

  python3 -m agents.scripts.train --logdir=/path/to/logdir --config=pendulum
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import os

import gym
import tensorflow as tf

from agents import tools
from agents.scripts import configs
from agents.scripts import utility


def _create_environment(config):
  """Constructor for an instance of the environment.

  Args:
    config: Object providing configurations via attributes.

  Returns:
    Wrapped OpenAI Gym environment.
  """
  if isinstance(config.env, str):
    env = gym.make(config.env)
  else:
    env = config.env()
  if config.max_length:
    env = tools.wrappers.LimitDuration(env, config.max_length)
  env = tools.wrappers.RangeNormalize(env)
  env = tools.wrappers.ClipAction(env)
  env = tools.wrappers.ConvertTo32Bit(env)
  return env


def _define_loop(graph, logdir, train_steps, eval_steps):
  """Create and configure a training loop with training and evaluation phases.

  Args:
    graph: Object providing graph elements via attributes.
    logdir: Log directory for storing checkpoints and summaries.
    train_steps: Number of training steps per epoch.
    eval_steps: Number of evaluation steps per epoch.

  Returns:
    Loop object.
  """
  loop = tools.Loop(
      logdir, graph.step, graph.should_log, graph.do_report,
      graph.force_reset)
  loop.add_phase(
      'train', graph.done, graph.score, graph.summary, train_steps,
      report_every=train_steps,
      log_every=train_steps // 2,
      checkpoint_every=None,
      feed={graph.is_training: True})
  loop.add_phase(
      'eval', graph.done, graph.score, graph.summary, eval_steps,
      report_every=eval_steps,
      log_every=eval_steps // 2,
      checkpoint_every=10 * eval_steps,
      feed={graph.is_training: False})
  return loop


def train(config, env_processes):
  """Training and evaluation entry point yielding scores.

  Resolves some configuration attributes, creates environments, graph, and
  training loop. By default, assigns all operations to the CPU.

  Args:
    config: Object providing configurations via attributes.
    env_processes: Whether to step environments in separate processes.

  Yields:
    Evaluation scores.
  """
  tf.reset_default_graph()
  if config.update_every % config.num_agents:
    tf.logging.warn('Number of agents should divide episodes per update.')
  with tf.device('/cpu:0'):
    batch_env = utility.define_batch_env(
        lambda: _create_environment(config),
        config.num_agents, env_processes)
    graph = utility.define_simulation_graph(
        batch_env, config.algorithm, config)
    loop = _define_loop(
        graph, config.logdir,
        config.update_every * config.max_length,
        config.eval_episodes * config.max_length)
    total_steps = int(
        config.steps / config.update_every *
        (config.update_every + config.eval_episodes))
  # Exclude episode-related variables since the Python state of environments
  # is not checkpointed and thus new episodes start after resuming.
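  # Note: the exclusion pattern below is assumed to match variables created
  # under a `_temporary` name scope, such as in-progress episode buffers;
  # these are rebuilt from scratch on restart, so skipping them is safe.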
  saver = utility.define_saver(exclude=(r'.*_temporary/.*',))
  sess_config = tf.ConfigProto(allow_soft_placement=True)
  sess_config.gpu_options.allow_growth = True
  with tf.Session(config=sess_config) as sess:
    utility.initialize_variables(sess, saver, config.logdir)
    for score in loop.run(sess, saver, total_steps):
      yield score
  batch_env.close()


def main(_):
  """Create or load configuration and launch the trainer."""
  utility.set_up_logging()
  if not FLAGS.config:
    raise KeyError('You must specify a configuration.')
  logdir = FLAGS.logdir and os.path.expanduser(os.path.join(
      FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp, FLAGS.config)))
  try:
    config = utility.load_config(logdir)
  except IOError:
    config = tools.AttrDict(getattr(configs, FLAGS.config)())
    config = utility.save_config(config, logdir)
  for score in train(config, FLAGS.env_processes):
    tf.logging.info('Score {}.'.format(score))


if __name__ == '__main__':
  FLAGS = tf.app.flags.FLAGS
  tf.app.flags.DEFINE_string(
      'logdir', None,
      'Base directory to store logs.')
  tf.app.flags.DEFINE_string(
      'timestamp', datetime.datetime.now().strftime('%Y%m%dT%H%M%S'),
      'Sub directory to store logs.')
  tf.app.flags.DEFINE_string(
      'config', None,
      'Configuration to execute.')
  tf.app.flags.DEFINE_boolean(
      'env_processes', True,
      'Step environments in separate processes to circumvent the GIL.')
  tf.app.run()
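
# Minimal programmatic usage sketch (illustrative, not part of the original
# module; assumes the `pendulum` example config in `agents.scripts.configs`
# and a hypothetical log directory path):
#
#   config = tools.AttrDict(configs.pendulum())
#   config = utility.save_config(config, '/tmp/agents-pendulum')
#   for score in train(config, env_processes=True):
#     print('Evaluation score:', score)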