remove sonnet dependency
This commit is contained in:
@@ -1,21 +0,0 @@
|
|||||||
"""An actor network."""
|
|
||||||
import tensorflow as tf
|
|
||||||
import sonnet as snt
|
|
||||||
|
|
||||||
class ActorNetwork(snt.AbstractModule):
|
|
||||||
"""An actor network as a sonnet Module."""
|
|
||||||
|
|
||||||
def __init__(self, layer_sizes, action_size, name='target_actor'):
|
|
||||||
super(ActorNetwork, self).__init__(name=name)
|
|
||||||
self._layer_sizes = layer_sizes
|
|
||||||
self._action_size = action_size
|
|
||||||
|
|
||||||
def _build(self, inputs):
|
|
||||||
state = inputs
|
|
||||||
for output_size in self._layer_sizes:
|
|
||||||
state = snt.Linear(output_size)(state)
|
|
||||||
state = tf.nn.relu(state)
|
|
||||||
|
|
||||||
action = tf.tanh(
|
|
||||||
snt.Linear(self._action_size, name='action')(state))
|
|
||||||
return action
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Loads a DDPG agent without too much external dependencies
|
|
||||||
"""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import collections
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
import sonnet as snt
|
|
||||||
from agents import actor_net
|
|
||||||
|
|
||||||
class SimpleAgent():
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
session,
|
|
||||||
ckpt_path,
|
|
||||||
actor_layer_size,
|
|
||||||
observation_size=(31,),
|
|
||||||
action_size=8,
|
|
||||||
):
|
|
||||||
self._ckpt_path = ckpt_path
|
|
||||||
self._actor_layer_size = actor_layer_size
|
|
||||||
self._observation_size = observation_size
|
|
||||||
self._action_size = action_size
|
|
||||||
self._session = session
|
|
||||||
self._build()
|
|
||||||
|
|
||||||
def _build(self):
|
|
||||||
self._agent_net = actor_net.ActorNetwork(self._actor_layer_size, self._action_size)
|
|
||||||
self._obs = tf.placeholder(tf.float32, (31,))
|
|
||||||
with tf.name_scope('Act'):
|
|
||||||
batch_obs = snt.nest.pack_iterable_as(self._obs,
|
|
||||||
snt.nest.map(lambda x: tf.expand_dims(x, 0),
|
|
||||||
snt.nest.flatten_iterable(self._obs)))
|
|
||||||
self._action = self._agent_net(batch_obs)
|
|
||||||
saver = tf.train.Saver()
|
|
||||||
saver.restore(
|
|
||||||
sess=self._session,
|
|
||||||
save_path=self._ckpt_path)
|
|
||||||
|
|
||||||
def __call__(self, observation):
|
|
||||||
out_action = self._session.run(self._action, feed_dict={self._obs: observation})
|
|
||||||
return out_action[0]
|
|
||||||
36
examples/pybullet/gym/agents/simplerAgent.py
Normal file
36
examples/pybullet/gym/agents/simplerAgent.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Loads a DDPG agent without too much external dependencies
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import collections
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
class SimplerAgent():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session,
|
||||||
|
ckpt_path,
|
||||||
|
observation_dim=31
|
||||||
|
):
|
||||||
|
self._ckpt_path = ckpt_path
|
||||||
|
self._session = session
|
||||||
|
self._observation_dim = observation_dim
|
||||||
|
self._build()
|
||||||
|
|
||||||
|
def _build(self):
|
||||||
|
saver = tf.train.import_meta_graph(self._ckpt_path + '.meta')
|
||||||
|
saver.restore(
|
||||||
|
sess=self._session,
|
||||||
|
save_path=self._ckpt_path)
|
||||||
|
self._action = tf.get_collection('action_op')[0]
|
||||||
|
self._obs = tf.get_collection('observation_placeholder')[0]
|
||||||
|
|
||||||
|
def __call__(self, observation):
|
||||||
|
feed_dict={self._obs: observation}
|
||||||
|
out_action = self._session.run(self._action, feed_dict=feed_dict)
|
||||||
|
return out_action[0]
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
model_checkpoint_path: "/cns/ij-d/home/jietan/persistent/minitaur/minitaur_vizier_3_153645653/Bullet/MinitaurSimEnv/28158/0003600000/agent/tf_graph_data/tf_graph_data.ckpt"
|
model_checkpoint_path: "tf_graph_data_converted.ckpt-0"
|
||||||
all_model_checkpoint_paths: "/cns/ij-d/home/jietan/persistent/minitaur/minitaur_vizier_3_153645653/Bullet/MinitaurSimEnv/28158/0003600000/agent/tf_graph_data/tf_graph_data.ckpt"
|
all_model_checkpoint_paths: "tf_graph_data_converted.ckpt-0"
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -10,7 +10,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from envs.bullet.minitaurGymEnv import MinitaurGymEnv
|
from envs.bullet.minitaurGymEnv import MinitaurGymEnv
|
||||||
from agents import simpleAgent
|
from agents import simplerAgent
|
||||||
|
|
||||||
def testSinePolicy():
|
def testSinePolicy():
|
||||||
"""Tests sine policy
|
"""Tests sine policy
|
||||||
@@ -53,17 +53,14 @@ def testDDPGPolicy():
|
|||||||
environment = MinitaurGymEnv(render=True)
|
environment = MinitaurGymEnv(render=True)
|
||||||
sum_reward = 0
|
sum_reward = 0
|
||||||
steps = 1000
|
steps = 1000
|
||||||
ckpt_path = 'data/agent/tf_graph_data/tf_graph_data.ckpt'
|
ckpt_path = 'data/agent/tf_graph_data/tf_graph_data_converted.ckpt-0'
|
||||||
observation_shape = (31,)
|
observation_shape = (31,)
|
||||||
action_size = 8
|
action_size = 8
|
||||||
actor_layer_sizes = (100, 181)
|
actor_layer_sizes = (100, 181)
|
||||||
n_steps = 0
|
n_steps = 0
|
||||||
tf.reset_default_graph()
|
tf.reset_default_graph()
|
||||||
with tf.Session() as session:
|
with tf.Session() as session:
|
||||||
agent = simpleAgent.SimpleAgent(session, ckpt_path,
|
agent = simplerAgent.SimplerAgent(session, ckpt_path)
|
||||||
actor_layer_sizes,
|
|
||||||
observation_size=observation_shape,
|
|
||||||
action_size=action_size)
|
|
||||||
state = environment.reset()
|
state = environment.reset()
|
||||||
action = agent(state)
|
action = agent(state)
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
|
|||||||
Reference in New Issue
Block a user