add yapf style and apply yapf to format all Python files
This recreates pull request #2192
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Memory that stores episodes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@@ -43,10 +42,9 @@ class EpisodeMemory(object):
|
||||
self._scope = var_scope
|
||||
self._length = tf.Variable(tf.zeros(capacity, tf.int32), False)
|
||||
self._buffers = [
|
||||
tf.Variable(tf.zeros(
|
||||
[capacity, max_length] + elem.shape.as_list(),
|
||||
elem.dtype), False)
|
||||
for elem in template]
|
||||
tf.Variable(tf.zeros([capacity, max_length] + elem.shape.as_list(), elem.dtype), False)
|
||||
for elem in template
|
||||
]
|
||||
|
||||
def length(self, rows=None):
|
||||
"""Tensor holding the current length of episodes.
|
||||
@@ -72,13 +70,11 @@ class EpisodeMemory(object):
|
||||
"""
|
||||
rows = tf.range(self._capacity) if rows is None else rows
|
||||
assert rows.shape.ndims == 1
|
||||
assert_capacity = tf.assert_less(
|
||||
rows, self._capacity,
|
||||
message='capacity exceeded')
|
||||
assert_capacity = tf.assert_less(rows, self._capacity, message='capacity exceeded')
|
||||
with tf.control_dependencies([assert_capacity]):
|
||||
assert_max_length = tf.assert_less(
|
||||
tf.gather(self._length, rows), self._max_length,
|
||||
message='max length exceeded')
|
||||
assert_max_length = tf.assert_less(tf.gather(self._length, rows),
|
||||
self._max_length,
|
||||
message='max length exceeded')
|
||||
append_ops = []
|
||||
with tf.control_dependencies([assert_max_length]):
|
||||
for buffer_, elements in zip(self._buffers, transitions):
|
||||
@@ -86,8 +82,7 @@ class EpisodeMemory(object):
|
||||
indices = tf.stack([rows, timestep], 1)
|
||||
append_ops.append(tf.scatter_nd_update(buffer_, indices, elements))
|
||||
with tf.control_dependencies(append_ops):
|
||||
episode_mask = tf.reduce_sum(tf.one_hot(
|
||||
rows, self._capacity, dtype=tf.int32), 0)
|
||||
episode_mask = tf.reduce_sum(tf.one_hot(rows, self._capacity, dtype=tf.int32), 0)
|
||||
return self._length.assign_add(episode_mask)
|
||||
|
||||
def replace(self, episodes, length, rows=None):
|
||||
@@ -103,11 +98,11 @@ class EpisodeMemory(object):
|
||||
"""
|
||||
rows = tf.range(self._capacity) if rows is None else rows
|
||||
assert rows.shape.ndims == 1
|
||||
assert_capacity = tf.assert_less(
|
||||
rows, self._capacity, message='capacity exceeded')
|
||||
assert_capacity = tf.assert_less(rows, self._capacity, message='capacity exceeded')
|
||||
with tf.control_dependencies([assert_capacity]):
|
||||
assert_max_length = tf.assert_less_equal(
|
||||
length, self._max_length, message='max length exceeded')
|
||||
assert_max_length = tf.assert_less_equal(length,
|
||||
self._max_length,
|
||||
message='max length exceeded')
|
||||
replace_ops = []
|
||||
with tf.control_dependencies([assert_max_length]):
|
||||
for buffer_, elements in zip(self._buffers, episodes):
|
||||
|
||||
Reference in New Issue
Block a user