diff --git a/examples/pybullet/gym/agents/actor_net.py b/examples/pybullet/gym/agents/actor_net.py deleted file mode 100644 index ac6aaff8a..000000000 --- a/examples/pybullet/gym/agents/actor_net.py +++ /dev/null @@ -1,21 +0,0 @@ -"""An actor network.""" -import tensorflow as tf -import sonnet as snt - -class ActorNetwork(snt.AbstractModule): - """An actor network as a sonnet Module.""" - - def __init__(self, layer_sizes, action_size, name='target_actor'): - super(ActorNetwork, self).__init__(name=name) - self._layer_sizes = layer_sizes - self._action_size = action_size - - def _build(self, inputs): - state = inputs - for output_size in self._layer_sizes: - state = snt.Linear(output_size)(state) - state = tf.nn.relu(state) - - action = tf.tanh( - snt.Linear(self._action_size, name='action')(state)) - return action diff --git a/examples/pybullet/gym/agents/simpleAgent.py b/examples/pybullet/gym/agents/simpleAgent.py deleted file mode 100644 index 08a4cf1fa..000000000 --- a/examples/pybullet/gym/agents/simpleAgent.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Loads a DDPG agent without too much external dependencies -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import collections -import numpy as np -import tensorflow as tf - -import sonnet as snt -from agents import actor_net - -class SimpleAgent(): - def __init__( - self, - session, - ckpt_path, - actor_layer_size, - observation_size=(31,), - action_size=8, - ): - self._ckpt_path = ckpt_path - self._actor_layer_size = actor_layer_size - self._observation_size = observation_size - self._action_size = action_size - self._session = session - self._build() - - def _build(self): - self._agent_net = actor_net.ActorNetwork(self._actor_layer_size, self._action_size) - self._obs = tf.placeholder(tf.float32, (31,)) - with tf.name_scope('Act'): - batch_obs = snt.nest.pack_iterable_as(self._obs, - snt.nest.map(lambda x: tf.expand_dims(x, 0), - snt.nest.flatten_iterable(self._obs))) - self._action = self._agent_net(batch_obs) - saver = tf.train.Saver() - saver.restore( - sess=self._session, - save_path=self._ckpt_path) - - def __call__(self, observation): - out_action = self._session.run(self._action, feed_dict={self._obs: observation}) - return out_action[0] diff --git a/examples/pybullet/gym/agents/simplerAgent.py b/examples/pybullet/gym/agents/simplerAgent.py new file mode 100644 index 000000000..4f12f04db --- /dev/null +++ b/examples/pybullet/gym/agents/simplerAgent.py @@ -0,0 +1,36 @@ +"""Loads a DDPG agent without too much external dependencies +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import collections +import numpy as np +import tensorflow as tf +import pdb + +class SimplerAgent(): + def __init__( + self, + session, + ckpt_path, + observation_dim=31 + ): + self._ckpt_path = ckpt_path + self._session = session + self._observation_dim = observation_dim + self._build() + + def _build(self): + saver = tf.train.import_meta_graph(self._ckpt_path + '.meta') + saver.restore( + sess=self._session, + save_path=self._ckpt_path) + self._action = tf.get_collection('action_op')[0] + self._obs = tf.get_collection('observation_placeholder')[0] + + def __call__(self, observation): + feed_dict={self._obs: observation} + out_action = self._session.run(self._action, feed_dict=feed_dict) + return out_action[0] diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/checkpoint b/examples/pybullet/gym/data/agent/tf_graph_data/checkpoint index 2b11c1281..72cc2c323 100644 --- a/examples/pybullet/gym/data/agent/tf_graph_data/checkpoint +++ b/examples/pybullet/gym/data/agent/tf_graph_data/checkpoint @@ -1,2 +1,2 @@ -model_checkpoint_path: "/cns/ij-d/home/jietan/persistent/minitaur/minitaur_vizier_3_153645653/Bullet/MinitaurSimEnv/28158/0003600000/agent/tf_graph_data/tf_graph_data.ckpt" -all_model_checkpoint_paths: "/cns/ij-d/home/jietan/persistent/minitaur/minitaur_vizier_3_153645653/Bullet/MinitaurSimEnv/28158/0003600000/agent/tf_graph_data/tf_graph_data.ckpt" +model_checkpoint_path: "tf_graph_data_converted.ckpt-0" +all_model_checkpoint_paths: "tf_graph_data_converted.ckpt-0" diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.data-00000-of-00001 b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.data-00000-of-00001 deleted file mode 100644 index b25aa2872..000000000 Binary files a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.data-00000-of-00001 and /dev/null differ diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.index b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.index deleted file mode 100644 index 8abcb6ea5..000000000 Binary files a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.index and /dev/null differ diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.meta b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.meta deleted file mode 100644 index e1369a3d1..000000000 Binary files a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data.ckpt.meta and /dev/null differ diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.data-00000-of-00001 b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.data-00000-of-00001 new file mode 100644 index 000000000..4d4eb02c3 Binary files /dev/null and b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.data-00000-of-00001 differ diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.index b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.index new file mode 100644 index 000000000..9f923be82 Binary files /dev/null and b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.index differ diff --git a/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.meta b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.meta new file mode 100644 index 000000000..2a70ea383 Binary files /dev/null and b/examples/pybullet/gym/data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0.meta differ diff --git a/examples/pybullet/gym/minitaurGymEnvTest.py b/examples/pybullet/gym/minitaurGymEnvTest.py index ff5db8500..3bbabd9a4 100644 --- a/examples/pybullet/gym/minitaurGymEnvTest.py +++ b/examples/pybullet/gym/minitaurGymEnvTest.py @@ -10,7 +10,7 @@ import numpy as np import tensorflow as tf from envs.bullet.minitaurGymEnv import MinitaurGymEnv -from agents import simpleAgent +from agents import simplerAgent def testSinePolicy(): """Tests sine policy @@ -53,17 +53,14 @@ def testDDPGPolicy(): environment = MinitaurGymEnv(render=True) sum_reward = 0 steps = 1000 - ckpt_path = 'data/agent/tf_graph_data/tf_graph_data.ckpt' + ckpt_path = 'data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0' observation_shape = (31,) action_size = 8 actor_layer_sizes = (100, 181) n_steps = 0 tf.reset_default_graph() with tf.Session() as session: - agent = simpleAgent.SimpleAgent(session, ckpt_path, - actor_layer_sizes, - observation_size=observation_shape, - action_size=action_size) + agent = simplerAgent.SimplerAgent(session, ckpt_path) state = environment.reset() action = agent(state) for _ in range(steps):