Add yapf style and apply yapf to format all Python files

This recreates pull request #2192.
This commit is contained in:
Erwin Coumans
2019-04-27 07:31:15 -07:00
parent c591735042
commit ef9570c315
347 changed files with 70304 additions and 22752 deletions

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to train a batch reinforcement learning algorithm.
Command line:
@@ -67,21 +66,25 @@ def _define_loop(graph, logdir, train_steps, eval_steps):
Returns:
Loop object.
"""
loop = tools.Loop(
logdir, graph.step, graph.should_log, graph.do_report,
graph.force_reset)
loop.add_phase(
'train', graph.done, graph.score, graph.summary, train_steps,
report_every=train_steps,
log_every=train_steps // 2,
checkpoint_every=None,
feed={graph.is_training: True})
loop.add_phase(
'eval', graph.done, graph.score, graph.summary, eval_steps,
report_every=eval_steps,
log_every=eval_steps // 2,
checkpoint_every=10 * eval_steps,
feed={graph.is_training: False})
loop = tools.Loop(logdir, graph.step, graph.should_log, graph.do_report, graph.force_reset)
loop.add_phase('train',
graph.done,
graph.score,
graph.summary,
train_steps,
report_every=train_steps,
log_every=train_steps // 2,
checkpoint_every=None,
feed={graph.is_training: True})
loop.add_phase('eval',
graph.done,
graph.score,
graph.summary,
eval_steps,
report_every=eval_steps,
log_every=eval_steps // 2,
checkpoint_every=10 * eval_steps,
feed={graph.is_training: False})
return loop
@@ -102,18 +105,13 @@ def train(config, env_processes):
if config.update_every % config.num_agents:
tf.logging.warn('Number of agents should divide episodes per update.')
with tf.device('/cpu:0'):
batch_env = utility.define_batch_env(
lambda: _create_environment(config),
config.num_agents, env_processes)
graph = utility.define_simulation_graph(
batch_env, config.algorithm, config)
loop = _define_loop(
graph, config.logdir,
config.update_every * config.max_length,
config.eval_episodes * config.max_length)
total_steps = int(
config.steps / config.update_every *
(config.update_every + config.eval_episodes))
batch_env = utility.define_batch_env(lambda: _create_environment(config), config.num_agents,
env_processes)
graph = utility.define_simulation_graph(batch_env, config.algorithm, config)
loop = _define_loop(graph, config.logdir, config.update_every * config.max_length,
config.eval_episodes * config.max_length)
total_steps = int(config.steps / config.update_every *
(config.update_every + config.eval_episodes))
# Exclude episode related variables since the Python state of environments is
# not checkpointed and thus new episodes start after resuming.
saver = utility.define_saver(exclude=(r'.*_temporary/.*',))
@@ -131,8 +129,8 @@ def main(_):
utility.set_up_logging()
if not FLAGS.config:
raise KeyError('You must specify a configuration.')
logdir = FLAGS.logdir and os.path.expanduser(os.path.join(
FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp, FLAGS.config)))
logdir = FLAGS.logdir and os.path.expanduser(
os.path.join(FLAGS.logdir, '{}-{}'.format(FLAGS.timestamp, FLAGS.config)))
try:
config = utility.load_config(logdir)
except IOError:
@@ -144,16 +142,11 @@ def main(_):
if __name__ == '__main__':
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
'logdir', None,
'Base directory to store logs.')
tf.app.flags.DEFINE_string(
'timestamp', datetime.datetime.now().strftime('%Y%m%dT%H%M%S'),
'Sub directory to store logs.')
tf.app.flags.DEFINE_string(
'config', None,
'Configuration to execute.')
tf.app.flags.DEFINE_boolean(
'env_processes', True,
'Step environments in separate processes to circumvent the GIL.')
tf.app.flags.DEFINE_string('logdir', None, 'Base directory to store logs.')
tf.app.flags.DEFINE_string('timestamp',
datetime.datetime.now().strftime('%Y%m%dT%H%M%S'),
'Sub directory to store logs.')
tf.app.flags.DEFINE_string('config', None, 'Configuration to execute.')
tf.app.flags.DEFINE_boolean('env_processes', True,
'Step environments in separate processes to circumvent the GIL.')
tf.app.run()