# Copyright 2017 The TensorFlow Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalize tensors based on streaming estimates of mean and variance."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class StreamingNormalize(object):
  """Normalize tensors based on streaming estimates of mean and variance.

  The running state consists of three non-trainable variables: the number of
  samples seen, the current mean, and the accumulated sum of squared
  deviations from the mean.
  """

  def __init__(
      self, template, center=True, scale=True, clip=10, name='normalize'):
    """Normalize tensors based on streaming estimates of mean and variance.

    Centering the value, scaling it by the standard deviation, and clipping
    outlier values are optional.

    Args:
      template: Example tensor providing shape and dtype of the value to track.
      center: Python boolean indicating whether to subtract mean from values.
      scale: Python boolean indicating whether to scale values by stddev.
      clip: Normalized values are clipped to the range `[-clip, clip]`. A
          falsy value (e.g. `0` or `None`) disables clipping.
      name: Parent scope of operations provided by this class.
    """
    self._center = center
    self._scale = scale
    self._clip = clip
    self._name = name
    with tf.name_scope(name):
      # The state is only ever changed by `update()` and `reset()`, so the
      # variables are created as non-trainable.
      self._count = tf.Variable(0, trainable=False)
      self._mean = tf.Variable(tf.zeros_like(template), trainable=False)
      self._var_sum = tf.Variable(tf.zeros_like(template), trainable=False)

  def transform(self, value):
    """Normalize a single or batch tensor.

    Applies the activated transformations in the constructor using current
    estimates of mean and variance.

    Args:
      value: Batch or single value tensor.

    Returns:
      Normalized batch or single value tensor.
    """
    with tf.name_scope(self._name + '/transform'):
      # A value of the same rank as the tracked mean has no batch dimension.
      no_batch_dim = value.shape.ndims == self._mean.shape.ndims
      if no_batch_dim:
        # Add a batch dimension if necessary.
        value = value[None, ...]
      if self._center:
        value -= self._mean[None, ...]
      if self._scale:
        # We cannot scale before seeing at least two samples.
        value /= tf.cond(
            self._count > 1,
            lambda: self._std() + 1e-8,
            lambda: tf.ones_like(self._var_sum))[None]
      if self._clip:
        value = tf.clip_by_value(value, -self._clip, self._clip)
      # Remove batch dimension if necessary.
      if no_batch_dim:
        value = value[0]
      return tf.check_numerics(value, 'value')

  def update(self, value):
    """Update the mean and variance estimates.

    Args:
      value: Batch or single value tensor.

    Returns:
      Summary tensor.
    """
    with tf.name_scope(self._name + '/update'):
      if value.shape.ndims == self._mean.shape.ndims:
        # Add a batch dimension if necessary.
        value = value[None, ...]
      count = tf.shape(value)[0]
      with tf.control_dependencies([self._count.assign_add(count)]):
        # Cast the sample count to the dtype of the tracked statistics rather
        # than a fixed float32, so that templates of other floating point
        # types do not cause a dtype mismatch in the division below.
        step = tf.cast(self._count, self._mean.dtype.base_dtype)
        mean_delta = tf.reduce_sum(value - self._mean[None, ...], 0)
        new_mean = self._mean + mean_delta / step
        new_mean = tf.cond(self._count > 1, lambda: new_mean, lambda: value[0])
        # Accumulate the product of the deviations from the old and from the
        # new mean, summed over the batch dimension.
        var_delta = (
            value - self._mean[None, ...]) * (value - new_mean[None, ...])
        new_var_sum = self._var_sum + tf.reduce_sum(var_delta, 0)
      # Both new values are computed from the old state, so they must be
      # evaluated before either variable is overwritten.
      with tf.control_dependencies([new_mean, new_var_sum]):
        update = self._mean.assign(new_mean), self._var_sum.assign(new_var_sum)
      with tf.control_dependencies(update):
        # Reducing over all dimensions always yields a scalar, so no special
        # case for rank one values is required.
        return self._summary('value', tf.reduce_mean(value))

  def reset(self):
    """Reset the estimates of mean and variance.

    Resets the full state of this class.

    Returns:
      Operation.
    """
    with tf.name_scope(self._name + '/reset'):
      return tf.group(
          self._count.assign(0),
          self._mean.assign(tf.zeros_like(self._mean)),
          self._var_sum.assign(tf.zeros_like(self._var_sum)))

  def summary(self):
    """Summary string of mean and standard deviation.

    Returns:
      Summary tensor.
    """
    with tf.name_scope(self._name + '/summary'):
      # `str` called without arguments returns the empty string, which is used
      # in place of a summary while too few samples have been seen.
      mean_summary = tf.cond(
          self._count > 0, lambda: self._summary('mean', self._mean), str)
      std_summary = tf.cond(
          self._count > 1, lambda: self._summary('stddev', self._std()), str)
      return tf.summary.merge([mean_summary, std_summary])

  def _std(self):
    """Computes the current estimate of the standard deviation.

    Note that the standard deviation is not defined until at least two samples
    were seen.

    Returns:
      Tensor of current standard deviation, or NaN values if fewer than two
      samples were seen.
    """
    variance = tf.cond(
        self._count > 1,
        # Match the dtype of the accumulated statistics instead of assuming
        # float32.
        lambda: self._var_sum / tf.cast(
            self._count - 1, self._var_sum.dtype.base_dtype),
        lambda: tf.ones_like(self._var_sum) * float('nan'))
    # The epsilon corrects for small negative variance values caused by
    # the algorithm. It was empirically chosen to work with all environments
    # tested.
    return tf.sqrt(variance + 1e-4)

  def _summary(self, name, tensor):
    """Create a scalar or histogram summary matching the rank of the tensor.

    Args:
      name: Name for the summary.
      tensor: Tensor to summarize.

    Returns:
      Summary tensor.
    """
    if tensor.shape.ndims == 0:
      return tf.summary.scalar(name, tensor)
    return tf.summary.histogram(name, tensor)