add a simple DDPG agent and a policy
@@ -30,19 +30,17 @@ class SimpleAgent():
     def _build(self):
         self._agent_net = actor_net.ActorNetwork(self._actor_layer_size, self._action_size)
-        self._o_t = tf.placeholder(tf.float32, (31,))
+        self._obs = tf.placeholder(tf.float32, (31,))
         with tf.name_scope('Act'):
-            batch_o_t = snt.nest.pack_iterable_as(
-                self._o_t,
-                snt.nest.map(
-                    lambda x: tf.expand_dims(x, 0),
-                    snt.nest.flatten_iterable(self._o_t)))
-            self._action = self._agent_net(batch_o_t)
+            batch_obs = snt.nest.pack_iterable_as(self._obs,
+                snt.nest.map(lambda x: tf.expand_dims(x, 0),
+                    snt.nest.flatten_iterable(self._obs)))
+            self._action = self._agent_net(batch_obs)

         saver = tf.train.Saver()
         saver.restore(
             sess=self._session,
             save_path=self._ckpt_path)

     def __call__(self, observation):
-        out_action = self._session.run(self._action, feed_dict={self._o_t: observation})
+        out_action = self._session.run(self._action, feed_dict={self._obs: observation})
         return out_action[0]
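For context, the new code in the Act scope wraps a single unbatched observation into a batch of size 1 before calling the actor network. Below is a minimal standalone sketch of that snt.nest batching pattern, assuming TensorFlow 1.x graph mode and Sonnet 1.x; the only names taken from the diff are the snt.nest calls, and everything else here is hypothetical:

import sonnet as snt     # Sonnet 1.x provides the snt.nest utilities used in the diff
import tensorflow as tf  # TensorFlow 1.x graph-mode API

# A single observation placeholder with no batch dimension, shape (31,).
obs = tf.placeholder(tf.float32, (31,))

# Flatten the (possibly nested) observation, give each leaf tensor a
# leading batch dimension of 1 via tf.expand_dims, then repack the
# result into the original structure: each leaf goes (31,) -> (1, 31).
batch_obs = snt.nest.pack_iterable_as(
    obs,
    snt.nest.map(lambda x: tf.expand_dims(x, 0),
                 snt.nest.flatten_iterable(obs)))

The same shape bookkeeping explains the out_action[0] in __call__: the session returns a batch of one action, and indexing with [0] drops the batch dimension again.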