BatchNormalization in training mode without updating moving mean and variance?

Hello TensorFlow community,

I’m trying to find a way in TF2 to use the tf.keras.layers.BatchNormalization layer in training mode (i.e. normalizing using the statistics of the current batch) but without updating the moving mean and variance (for some batches, not all).

In TF1, using tf.layers.batch_normalization, you could do something like

x = my_first_inputs   # I want to use these data for updating the moving statistics
y = my_second_inputs  # I do not want to use these data for updating the moving statistics

out_x = my_model(x, training=True)
# Grab only the update ops created so far, i.e. the ones belonging to x
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
out_y = my_model(y, training=True)  # creates its own update ops, but they are never run

train_op = optimizer.minimize(loss)  # gradient step to minimize the loss

with tf.control_dependencies([train_op]):
    train_op = tf.group(*update_ops)

session.run(train_op)

Does anyone have an idea of how to replicate this in TF2?

Best,
Erik

Also, though not strictly related to batch norm, I see that TF Agents is still using the same API in TF 2.0 through the v1 namespace:

This is the migration guide:

Thank you for your answer.

Unfortunately I can’t really find a solution for my problem in the migration guide. All it says is that the moving statistics for BatchNorm will be updated automatically in TF2 when calling with “training=True”, which is what I don’t want.
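
Just to make the issue concrete, here is a minimal sketch of that default behavior (assuming eager execution in TF 2.x; shapes and values are made up for illustration):

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal([32, 8]) + 5.0   # non-zero mean so the moving-average update is visible

_ = bn(x, training=True)              # builds the layer and already updates the moving stats
before = bn.moving_mean.numpy().copy()
_ = bn(x, training=True)
after = bn.moving_mean.numpy()
print((before != after).any())        # True: moving_mean moved toward the batch mean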

I am also unfortunately not familiar enough with all the inner mechanics of TF2 to understand how the snippet from eager_utils.py helps me.

I think I finally found a fairly good solution to this. Posting in case anyone with the same problem finds this thread.

One cause of my original problem is that tf.keras.layers.BatchNormalization uses a custom behavior for layer.trainable = False. From the docs:

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.

This behavior can be disabled by subclassing tf.keras.layers.BatchNormalization and overriding the _get_training_value() method:

import tensorflow as tf
from tensorflow.keras import backend


class MyBatchNorm(tf.keras.layers.BatchNormalization):

    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if self._USE_V2_BEHAVIOR:
            if isinstance(training, int):
                training = bool(training)
            # These four lines from the stock implementation are commented out,
            # so trainable = False no longer forces inference mode:
            #if not self.trainable:
            #    # When the layer is not trainable, it overrides the value passed
            #    # from model.
            #    training = False
        return training

Note that the custom behavior for layer.trainable = False is disabled by commenting out those four lines.

We can then use batch normalization with batch statistics, but without updating the moving statistics, like this:

model = MyBatchNorm()

model.trainable = False
out = model(x, training=True)

whereas the unmodified tf.keras.layers.BatchNormalization would use the moving statistics in this call.
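
For anyone who wants to double-check this, here is a small verification sketch (assuming eager execution; the input x is made up, and the result should hold in the TF 2.x versions where this trick works as described):

import tensorflow as tf

x = tf.random.normal([32, 8]) + 5.0   # non-zero mean so a moving-average update would be visible

bn = MyBatchNorm()
_ = bn(x, training=True)              # first call builds the layer (and still updates the stats)

bn.trainable = False
before = bn.moving_mean.numpy().copy()
out = bn(x, training=True)            # normalizes with the batch statistics...
after = bn.moving_mean.numpy()
print((before == after).all())        # ...but should print True: the moving stats were not updated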
