add a temp copy of TF agents (until the API stops changing or configs.py are included)
This commit is contained in:
@@ -24,15 +24,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
|
||||
import gym
|
||||
import tensorflow as tf
|
||||
|
||||
from agents import tools
|
||||
from . import tools
|
||||
from . import configs
|
||||
from agents.scripts import utility
|
||||
from . import utility
|
||||
|
||||
|
||||
def _create_environment(config):
|
||||
@@ -73,7 +72,7 @@ def _define_loop(graph, logdir, train_steps, eval_steps):
|
||||
graph.force_reset)
|
||||
loop.add_phase(
|
||||
'train', graph.done, graph.score, graph.summary, train_steps,
|
||||
report_every=None,
|
||||
report_every=train_steps,
|
||||
log_every=train_steps // 2,
|
||||
checkpoint_every=None,
|
||||
feed={graph.is_training: True})
|
||||
@@ -100,9 +99,6 @@ def train(config, env_processes):
|
||||
Evaluation scores.
|
||||
"""
|
||||
tf.reset_default_graph()
|
||||
with config.unlocked:
|
||||
config.policy_optimizer = getattr(tf.train, config.policy_optimizer)
|
||||
config.value_optimizer = getattr(tf.train, config.value_optimizer)
|
||||
if config.update_every % config.num_agents:
|
||||
tf.logging.warn('Number of agents should divide episodes per update.')
|
||||
with tf.device('/cpu:0'):
|
||||
|
||||
Reference in New Issue
Block a user