BatchNormalization in training mode without updating moving mean and variance?

Hello TensorFlow community,

I’m trying to find a way in TF2 to use the tf.keras.layers.BatchNormalization layer in training mode (i.e. normalizing using the statistics of the current batch) but without updating the moving mean and variance (for some batches, not all).
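
To make the problem concrete, here is a minimal sketch of the default TF2 behaviour (toy shapes, purely for illustration):

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal([32, 10])

out = bn(x, training=True)  # normalizes with the batch statistics...
print(bn.moving_mean)       # ...but has also updated moving_mean / moving_variance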

In TF1, using tf.layers.batch_normalization, you could do something like

x = my_first_inputs   # I want these data to update the moving statistics
y = my_second_inputs  # I do not want these data to update the moving statistics

out_x = my_model(x, training=True)
# Collect the update ops now, so only the BN updates created by the x pass
# are included; the ops created by the y pass below never enter update_ops.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
out_y = my_model(y, training=True)

train_op = optimizer.minimize(loss)  # gradient step (optimizer and loss defined elsewhere)

with tf.control_dependencies([train_op]):
    train_op = tf.group(*update_ops)

session.run(train_op)

Does anyone have an idea of how to replicate this in TF2?

Best,
Erik

Also, though not strictly related to batch norm, I see that TF-Agents is still using the same API in TF 2.0 via the v1 namespace:

This is the migration guide:

Thank you for your answer.

Unfortunately, I can’t really find a solution to my problem in the migration guide. All it says is that the moving statistics for BatchNorm are updated automatically in TF2 when calling with “training=True”, which is exactly what I don’t want.

I am also unfortunately not familiar enough with all the inner mechanics of TF2 to understand how the snippet from eager_utils.py helps me.
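
For now, the best workaround I can come up with is to let the update happen and then undo it by hand: snapshot the moving statistics before the forward pass and restore them afterwards. A rough sketch, assuming eager execution and an already-built Keras model (the helper name is mine, untested):

import tensorflow as tf

def forward_without_bn_update(model, inputs):
    # Find every BatchNormalization layer, including nested ones.
    bn_layers = [m for m in model.submodules
                 if isinstance(m, tf.keras.layers.BatchNormalization)]
    # Snapshot the moving statistics (requires the layers to be built).
    snapshot = [(l.moving_mean.numpy(), l.moving_variance.numpy())
                for l in bn_layers]

    # Forward pass: normalizes with batch statistics, but in eager mode
    # this also updates the moving averages, which we undo below.
    out = model(inputs, training=True)

    # Restore the snapshots, discarding the unwanted updates.
    for l, (mean, var) in zip(bn_layers, snapshot):
        l.moving_mean.assign(mean)
        l.moving_variance.assign(var)
    return out

With this, out_x = my_model(x, training=True) would update the statistics as usual, while out_y = forward_without_bn_update(my_model, y) would not. It feels clumsy, though, so I’d still be glad to hear of a cleaner way.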