I’m training a model whose inputs and outputs are images with the same shape
(H, W, C).
My loss function is MSE over these images, but computed in another color space.
The color space conversion is defined by the
transform_space function, which takes and returns a single image.
I subclassed tf.keras.losses.Loss to implement this loss.
Its call method, however, receives images not one by one, but in batches of shape
(None, H, W, C).
The problem is that the first dimension of this batch is None, i.e. the batch size is dynamic.
I tried iterating over these batches, but got the error
"iterating over tf.Tensor is not allowed".
So, how should I calculate my loss?
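For context, the "iterating is not allowed" restriction is usually worked around with tf.map_fn, which maps a per-element function over the leading (batch) dimension even when that dimension is dynamic. A minimal sketch, using a placeholder in place of the real transform_space:

```python
import tensorflow as tf

# Placeholder standing in for the real transform_space: any
# per-image (H, W, C) -> (H, W, C) function built from tf ops.
def transform_space(image):
    return 1.0 - image  # hypothetical "transformation"

batch = tf.random.uniform((4, 8, 8, 3))

# tf.map_fn applies the per-image function along the first
# (batch) axis, so the dynamic batch size is not an obstacle.
transformed = tf.map_fn(transform_space, batch)
print(transformed.shape)  # (4, 8, 8, 3)
```

Note that tf.map_fn traces the function once and runs it per element, which can be slower than a fully vectorized transform.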
The reasons I can’t use the new color space as the input and output of my model:
- the model uses one of the pretrained
tf.keras.applications models, which works with images in the original color space
- the reverse transformation can’t be done, because part of the information is lost during the transformation
I’m using tf.distribute.MirroredStrategy, if that matters.
```python
# Takes an image of shape (H, W, C),
# converts it to a new color space
# and returns a new image with shape (H, W, C)
def transform_space(image):
    # ...color space transformation...
    return image_in_a_new_color_space


class MyCustomLoss(tf.keras.losses.Loss):
    def __init__(self):
        super().__init__()
        # The loss function is defined this way
        # due to the fact that I use "tf.distribute.MirroredStrategy"
        mse = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)
        self.loss_fn = lambda true, pred: tf.math.reduce_mean(mse(true, pred))

    def call(self, true_batch, pred_batch):
        # Since the shape of true/pred_batch is (None, H, W, C)
        # and transform_space expects shape (H, W, C),
        # the following transformations are impossible:
        true_batch_transformed = transform_space(true_batch)
        pred_batch_transformed = transform_space(pred_batch)
        return self.loss_fn(true_batch_transformed, pred_batch_transformed)
```
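As an aside: if the conversion inside transform_space can be expressed with broadcasting tensor ops (many color space conversions are a per-pixel linear map), the same function often accepts the whole (None, H, W, C) batch unchanged, with no iteration at all. A hedged sketch with a made-up 3x3 channel-mixing matrix (not the author's actual transform):

```python
import tensorflow as tf

# Hypothetical channel-mixing matrix standing in for a real
# color space conversion (NOT the actual transform_space).
M = tf.constant([[0.5,  0.25, 0.25],
                 [0.25, 0.5,  0.25],
                 [0.25, 0.25, 0.5]])

def transform_space(image):
    # matmul contracts the last (channel) axis against M and
    # broadcasts over all leading axes, so the same code accepts
    # both (H, W, C) and (None, H, W, C) inputs.
    return tf.linalg.matmul(image, M)

single = tf.random.uniform((8, 8, 3))
batch = tf.stack([single, single])
print(transform_space(single).shape)  # (8, 8, 3)
print(transform_space(batch).shape)   # (2, 8, 8, 3)
```

Such a batch-transparent transform could then be called directly inside call on true_batch and pred_batch.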