Batch dimension is None in custom loss function in TensorFlow 2

I’m training a model whose inputs and outputs are images of the same shape (H, W, C).

My loss function is the MSE between these images, but computed in another color space.

The color space conversion is done by the transform_space function, which takes one image and returns one image.

I’m subclassing tf.keras.losses.Loss to implement this loss.

However, the call method receives images not one by one but in batches of shape (None, H, W, C).

The problem is that the first (batch) dimension of this tensor is None.
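A minimal sketch of why this happens, assuming the loss ends up inside a traced tf.function (as it does in model.fit); the function name inspect_batch and the 4×4 image size are illustrative only. While tracing, the static shape has None for the batch dimension, but tf.shape returns the actual batch size when the graph runs.

```python
import tensorflow as tf

# Hypothetical traced function with an unknown batch dimension in its signature.
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 4, 4, 3), dtype=tf.float32)])
def inspect_batch(batch):
    # Static shape, seen once during tracing: the batch dimension is None.
    print("static shape:", batch.shape)
    # Dynamic shape: the real batch size, known only when the graph executes.
    return tf.shape(batch)[0]

batch_size = inspect_batch(tf.zeros((8, 4, 4, 3)))
```

So code that needs the batch size has to use tf.shape (or an op that handles the batch axis itself) rather than Python-level shape values or loops.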

I tried to iterate over the images in the batch, but got the error “iterating over tf.Tensor is not allowed”.
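A common way around the iteration error is tf.map_fn, which applies a per-image function along axis 0 at run time, so an unknown (None) batch dimension is not a problem. The sketch below is an assumption-based illustration: transform_single_image is a hypothetical stand-in for transform_space, using tf.image.rgb_to_yuv as an example conversion that takes and returns one (H, W, C) image.

```python
import tensorflow as tf

# Hypothetical stand-in for transform_space: one (H, W, 3) image in, one out.
def transform_single_image(image):
    return tf.image.rgb_to_yuv(image)

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 4, 4, 3), dtype=tf.float32)])
def transform_batch(batch):
    # tf.map_fn unstacks the batch along axis 0 when the graph runs,
    # calls the per-image function on each element and stacks the results.
    return tf.map_fn(transform_single_image, batch)

images = tf.random.uniform((5, 4, 4, 3))
out = transform_batch(images)
```

tf.vectorized_map is an alternative that can be faster when the per-image function is vectorizable. If transform_space happens to be built only from ops that broadcast over leading dimensions (many tf.image color conversions accept [..., 3] inputs), it may also work on the whole batch directly.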

So, how should I calculate my loss?

The reasons I can’t use the new color space as the input and output of my model:

  • the model uses one of the pretrained tf.keras.applications models, which work with RGB
  • the reverse transformation is impossible because part of the information is lost during the transformation

I’m using tf.distribute.MirroredStrategy, if it matters.

```python
import tensorflow as tf

# Takes an image of shape (H, W, C),
# converts it to a new color space
# and returns a new image with shape (H, W, C)
def transform_space(image):
  # ...color space transformation...
  return image_in_a_new_color_space

class MyCustomLoss(tf.keras.losses.Loss):

  def __init__(self):
    # The base class must be initialized (it sets up `reduction` and `name`)
    super().__init__()

    # The loss function is defined this way
    # because I use tf.distribute.MirroredStrategy
    mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
    self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

  def call(self, true_batch, pred_batch):

    # Since true_batch and pred_batch have shape (None, H, W, C)
    # and transform_space expects shape (H, W, C),
    # the following transformations are impossible:
    true_batch_transformed = transform_space(true_batch)
    pred_batch_transformed = transform_space(pred_batch)

    return self.loss_fn(true_batch_transformed, pred_batch_transformed)
```
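One possible shape for this loss, as a sketch rather than the author’s implementation: apply the per-image transform with tf.map_fn inside call and return one loss value per image. Here transform_space is a hypothetical stand-in (tf.image.rgb_to_yuv), since the real conversion is not shown in the post, and the name "color_space_mse" is illustrative.

```python
import tensorflow as tf

# Hypothetical stand-in for transform_space (the real conversion is not shown).
def transform_space(image):
    return tf.image.rgb_to_yuv(image)

class MyCustomLoss(tf.keras.losses.Loss):
    def __init__(self, name="color_space_mse", **kwargs):
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        # Transform every image in the batch; the batch size only has to be
        # known at run time, so a None batch dimension is fine.
        true_t = tf.map_fn(transform_space, y_true)
        pred_t = tf.map_fn(transform_space, y_pred)
        # Per-example MSE with shape (batch,); the base Loss class then
        # applies its `reduction` (by default, averaging over the batch).
        return tf.reduce_mean(tf.square(true_t - pred_t), axis=[1, 2, 3])
```

Returning per-example values instead of a single scalar leaves the final reduction to Keras, which is the part that has to be aware of distribution strategies.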

Hi @Artem_Legotin

Welcome to the TensorFlow Forum!

Please provide minimal reproducible code along with the full error log so we can replicate the error; more information is needed to understand the issue. Meanwhile, you can refer to How to define Loss function in Distributed training, which might help with this issue. Thank you.
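For a custom training loop under tf.distribute.MirroredStrategy, a commonly recommended pattern is to compute per-example losses and average them over the global batch with tf.nn.compute_average_loss, so that summing the per-replica losses gives the correct mean. A minimal sketch under stated assumptions: GLOBAL_BATCH_SIZE and transform_space are hypothetical (the latter uses tf.image.rgb_to_yuv as a stand-in). When training with model.fit, Keras is generally expected to handle this scaling itself.

```python
import tensorflow as tf

# Hypothetical: per-replica batch size multiplied by the number of replicas.
GLOBAL_BATCH_SIZE = 8

# Hypothetical stand-in for transform_space.
def transform_space(image):
    return tf.image.rgb_to_yuv(image)

def compute_loss(y_true, y_pred):
    # Per-image color space transform, then per-example MSE of shape (batch,).
    true_t = tf.map_fn(transform_space, y_true)
    pred_t = tf.map_fn(transform_space, y_pred)
    per_example_loss = tf.reduce_mean(tf.square(true_t - pred_t), axis=[1, 2, 3])
    # Sum of per-example losses divided by the *global* batch size.
    return tf.nn.compute_average_loss(per_example_loss,
                                      global_batch_size=GLOBAL_BATCH_SIZE)
```

With two replicas each holding half of the global batch, each replica’s value is half of the full-batch mean, and summing them across replicas recovers the mean.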

