Custom training: include dummy axis for loss

Hi everyone!

I have a small question: in the custom training tutorial, the “Caution” box gives the following advice for pointwise losses (such as BCE):

“For pointwise losses like losses.mean_squared_error or losses.binary_crossentropy include a dummy axis so that [batch, W, H, 1] is reduced to [batch, W, H].”
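The effect described in the quote can be illustrated with a small NumPy sketch. It assumes, as the caution states, that a pointwise loss averages over the last axis. `pointwise_bce` is a hypothetical stand-in written in NumPy, not the Keras implementation, and the tensor sizes are arbitrary.

```python
import numpy as np

def pointwise_bce(y_true, y_pred, eps=1e-7):
    # Hypothetical NumPy stand-in: elementwise binary cross-entropy,
    # then a mean over the LAST axis (the reduction the caution warns about).
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    elementwise = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return elementwise.mean(axis=-1)

batch, W, H = 2, 4, 5
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=(batch, W, H)).astype(np.float64)
preds = rng.uniform(size=(batch, W, H))

# Without a dummy axis, H is averaged away: [batch, W, H] -> [batch, W]
no_dummy = pointwise_bce(labels, preds)
print(no_dummy.shape)  # (2, 4)

# With a trailing size-1 axis on both tensors, only that axis is reduced:
# [batch, W, H, 1] -> [batch, W, H]
with_dummy = pointwise_bce(labels[..., np.newaxis], preds[..., np.newaxis])
print(with_dummy.shape)  # (2, 4, 5)
```

Averaging `with_dummy` over its last axis gives back `no_dummy`, which shows that without the dummy axis the per-pixel losses are silently averaged over H.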

However, I’m not quite sure where to add the extra dummy dimension (presumably with tf.expand_dims()). Should it go inside the loss computation or in the batch processing?

To me, the bullet point above the caution box suggests adding the dimension to the predictions (i.e., inside the loss calculation). Does this mean the labels need to be expanded as well?
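A minimal sketch of one option, adding the dummy axis inside the loss computation. The helper `per_pixel_bce` is hypothetical; it expands both labels and predictions with `np.expand_dims` (the TensorFlow counterpart is `tf.expand_dims`). This plain NumPy version does no automatic shape alignment, so it says nothing about how a specific Keras loss treats mismatched ranks.

```python
import numpy as np

def elementwise_bce(y_true, y_pred, eps=1e-7):
    # Elementwise binary cross-entropy: no reduction, no shape alignment.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))

def per_pixel_bce(labels, predictions):
    # Hypothetical helper: add the dummy axis to labels AND predictions,
    # then reduce only that size-1 axis.
    labels = np.expand_dims(labels, axis=-1)            # [batch, W, H, 1]
    predictions = np.expand_dims(predictions, axis=-1)  # [batch, W, H, 1]
    return elementwise_bce(labels, predictions).mean(axis=-1)  # [batch, W, H]

batch, W, H = 2, 3, 4
rng = np.random.default_rng(1)
labels = rng.integers(0, 2, size=(batch, W, H)).astype(np.float64)
preds = rng.uniform(size=(batch, W, H))

loss = per_pixel_bce(labels, preds)
print(loss.shape)  # (2, 3, 4)

# Expanding only the predictions leaves the ranks mismatched; with these sizes
# NumPy cannot broadcast [2, 3, 4] against [2, 3, 4, 1] and raises ValueError.
try:
    elementwise_bce(labels, np.expand_dims(preds, axis=-1))
    broadcast_failed = False
except ValueError:
    broadcast_failed = True
print(broadcast_failed)  # True
```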

Thank you!


Hi @nmwitzig

Welcome to the TensorFlow Forum!

Please share minimal reproducible code or a link to the tutorial so we can check and understand the issue. Thank you.