Hello TensorFlow community,

I’m trying to find a way in TF2 to use the tf.keras.layers.BatchNormalization layer in training mode (i.e., normalizing with the statistics of the current batch) without updating the moving mean and variance, for some batches but not all.
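A minimal sketch of the desired behaviour, assuming TF 2.x in eager mode and a Keras BatchNormalization layer with its defaults (axis=-1, center=True, scale=True). It normalizes with tf.nn.moments over the current batch and applies the layer's own gamma/beta through tf.nn.batch_normalization, so bn.moving_mean and bn.moving_variance are never read or written. `batch_stat_bn` is a hypothetical helper, not a TensorFlow API, and because it bypasses the layer's call(), options such as renorm or virtual_batch_size are ignored.

```python
import numpy as np
import tensorflow as tf

def batch_stat_bn(bn, x):
    """Normalize `x` with the mean/variance of the current batch, reusing the
    gamma/beta of the Keras layer `bn`. The layer's moving_mean and
    moving_variance are neither read nor updated.
    Assumes axis=-1 (channels last) and center=scale=True (the defaults)."""
    reduce_axes = list(range(len(x.shape) - 1))  # every axis except channels
    mean, variance = tf.nn.moments(x, axes=reduce_axes)
    return tf.nn.batch_normalization(
        x, mean, variance,
        offset=bn.beta, scale=bn.gamma, variance_epsilon=bn.epsilon)

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal([32, 4], mean=3.0)
_ = bn(x, training=False)  # builds gamma/beta; inference mode leaves the statistics alone

before = bn.moving_mean.numpy().copy()
y_manual = batch_stat_bn(bn, x)  # differentiable w.r.t. x, bn.gamma and bn.beta
assert np.allclose(before, bn.moving_mean.numpy())  # moving mean unchanged
```

The output should match `bn(x, training=True)` up to floating-point rounding; the difference is only that the built-in call would also update the moving statistics.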

In TF1, using tf.layers.batch_normalization, you could do something like this:

```python
import tensorflow as tf  # TF 1.x graph mode

x = my_first_inputs   # I want to use these data for updating moving statistics
y = my_second_inputs  # I do not want to use these data for updating moving statistics

out_x = my_model(x, training=True)

# Snapshot of the UPDATE_OPS collection at this point: only the
# moving-statistics updates created by the call on x
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

out_y = my_model(y, training=True)

# …

train_op = ...  # gradient step to minimize loss

# …

# Run the saved updates (for x only) after the gradient step
with tf.control_dependencies([train_op]):
    train_op = tf.group(*update_ops)

session.run(train_op)
```
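One possible TF2 counterpart of the TF1 snippet, sketched under assumptions (TF 2.x eager execution, BatchNormalization layers sitting at the top level of the model): snapshot the moving statistics, run the batch that should not contribute in training mode, then assign the old values back. `call_without_bn_update` is a hypothetical helper rather than a TensorFlow API, and the model, optimizer and loss below are placeholders.

```python
import numpy as np
import tensorflow as tf

def call_without_bn_update(model, inputs):
    """Call `model` in training mode (BatchNormalization uses batch statistics),
    then put the moving mean/variance of its BatchNormalization layers back to
    their values from before the call. Only model.layers (top level) is scanned;
    BN layers inside nested sub-models would have to be collected separately."""
    bn_layers = [layer for layer in model.layers
                 if isinstance(layer, tf.keras.layers.BatchNormalization)]
    # tf.identity reads the variables now, giving snapshot tensors
    saved = [(tf.identity(layer.moving_mean), tf.identity(layer.moving_variance))
             for layer in bn_layers]
    outputs = model(inputs, training=True)
    for layer, (mean, var) in zip(bn_layers, saved):
        layer.moving_mean.assign(mean)
        layer.moving_variance.assign(var)
    return outputs

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8),
    tf.keras.layers.BatchNormalization(),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
x = tf.random.normal([16, 4])  # batch that should update the moving statistics
y = tf.random.normal([16, 4])  # batch that should not

with tf.GradientTape() as tape:
    out_x = model(x, training=True)             # updates moving statistics
    out_y = call_without_bn_update(model, y)    # batch statistics, no net update
    loss = tf.reduce_mean(out_x ** 2) + tf.reduce_mean(out_y ** 2)  # placeholder loss
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```

Unlike the TF1 version, where the updates for x run after the gradient step, here they are applied as soon as `model(x, training=True)` executes. Since training-mode outputs depend only on batch statistics, this ordering should not change the gradients.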

Does anyone have an idea of how to replicate this in TF2?

Best,

Erik