Hello TensorFlow community,
I’m looking for a way in TF2 to call the tf.keras.layers.BatchNormalization layer in training mode (i.e. normalizing with the statistics of the current batch) without updating the moving mean and variance, for some batches but not all.
In TF1, with tf.layers.batch_normalization, you could do something like this:
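x = my_first_inputs   # use these data to update the moving statistics
y = my_second_inputs  # do NOT use these data to update the moving statistics
out_x = my_model(x, training=True)
# At this point UPDATE_OPS holds only the moving-mean/variance updates for x
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
out_y = my_model(y, training=True)  # adds more update ops, which are never run
train_step = optimizer.minimize(loss)  # gradient step on a loss built from out_x and out_y
train_op = tf.group(train_step, *update_ops)  # run the step plus only x's updates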
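For comparison, here is one possible TF2 workaround, offered only as a sketch under stated assumptions, not an established API. TF2 has no UPDATE_OPS collection; a Keras BatchNormalization layer called with training=True updates its moving statistics as part of the call itself. The sketch therefore snapshots the model's non-trainable variables (for BatchNormalization these are moving_mean and moving_variance) before the forward pass on y and restores them afterwards. It assumes the model is already built and that restoring every non-trainable variable is acceptable; `forward_without_stat_update`, `train_step` and `loss_fn` are hypothetical names.

```python
import numpy as np
import tensorflow as tf


def forward_without_stat_update(model, inputs):
    """Run `model` in training mode, then undo any moving-statistics update.

    Assumption: the model is already built, so its non-trainable variables
    (BatchNormalization moving_mean / moving_variance) exist before the call.
    """
    variables = model.non_trainable_variables
    # Snapshot the current values of the non-trainable variables.
    saved = [tf.identity(v) for v in variables]
    # Normalizes with the batch statistics and updates the moving statistics.
    outputs = model(inputs, training=True)
    # Restore the snapshot so this batch leaves no trace in the moving stats.
    for v, old in zip(variables, saved):
        v.assign(old)
    return outputs


def train_step(model, optimizer, loss_fn, x, y):
    """Hypothetical TF2 analogue of the TF1 snippet above."""
    with tf.GradientTape() as tape:
        out_x = model(x, training=True)                # x updates the moving stats
        out_y = forward_without_stat_update(model, y)  # y does not
        loss = loss_fn(out_x, out_y)                   # loss_fn is hypothetical
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```

Restoring after the fact costs one extra copy of the moving statistics per call, but it avoids subclassing BatchNormalization or relying on private Keras internals.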
Does anyone have an idea of how to replicate this in TF2?