Doubts in loss scaling during distributed training

I had posted something similar before but couldn’t get around the problem.

I want to know how to properly scale the final loss when it is composed of two or more individual loss terms, as is common in use cases like image generation. For example, a typical image generation loss is the following:

total_loss = (alpha * perceptual_loss) + (beta * l1_loss)

Note that the non-reduced versions of perceptual_loss and l1_loss do not have the same shapes, so I guess it becomes necessary to reduce them before summing them up. The official guide does provide some guidance, but I could not find anything concrete about this particular use case.

Any pointers regarding this would be very helpful.


I faced the same problem a couple of years ago, and we came up with this solution:

From ashpy/

    import functools
    from typing import Callable

    import tensorflow as tf


    def reduce_loss(call_fn: Callable) -> Callable:
        """Create a decorator to reduce losses. Used to simplify things.

        Apply a ``reduce sum`` operation to the loss and divide the result
        by the batch size.

        Args:
            call_fn (:py:obj:`typing.Callable`): The executor call method.

        Returns:
            :py:obj:`typing.Callable`: The decorated function.
        """
        @functools.wraps(call_fn)
        def _reduce(self, *args, **kwargs):
            # Sum the per-example losses and divide by the global batch size.
            return tf.nn.compute_average_loss(
                call_fn(self, *args, **kwargs),
                global_batch_size=self._global_batch_size,  # pylint: disable=protected-access
            )

        return _reduce

We apply this decorator to each loss. For example, this is how we use it for the L1 loss used during adversarial training:

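    import tensorflow as tf
    from ashpy.contexts import GANContext

    # GANExecutor also comes from ashpy; reduce_loss is the decorator defined above.


    class GeneratorL1(GANExecutor):
        r"""L1 loss between the generator output and the target.

        .. math::
            L_G = E ||x - G(z)||_1

        Where x is the target and G(z) is the generated image.
        """

        def __init__(self) -> None:
            """Initialize the Executor."""
            # Body omitted in the excerpt; presumably it hands an element-wise
            # (unreduced) L1 loss to the parent class, which stores it as self._fn.

        @reduce_loss
        def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
            """Call the carried loss on `fake` and `real`.

            Args:
                context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
                fake (:py:class:`tf.Tensor`): Fake data (generated).
                real (:py:class:`tf.Tensor`): Real data.

            Returns:
                :py:class:`tf.Tensor`: Output Tensor.
            """
            mae = self._fn(fake, real)
            return mae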
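The decorator's effect can be sketched without TensorFlow or ashpy. In this sketch, fake_compute_average_loss is a hypothetical stand-in that mimics only the basic arithmetic of tf.nn.compute_average_loss (sum of the per-example losses divided by global_batch_size, ignoring sample weights), and ToyL1Executor is a made-up class, not ashpy's GeneratorL1:

```python
import functools
from typing import Callable, List


def fake_compute_average_loss(per_example_loss: List[float], global_batch_size: int) -> float:
    """Hypothetical stand-in for tf.nn.compute_average_loss (no sample weights)."""
    return sum(per_example_loss) / global_batch_size


def reduce_loss(call_fn: Callable) -> Callable:
    """Decorator: reduce the per-example losses returned by call_fn to a scalar."""

    @functools.wraps(call_fn)
    def _reduce(self, *args, **kwargs):
        return fake_compute_average_loss(
            call_fn(self, *args, **kwargs),
            global_batch_size=self._global_batch_size,
        )

    return _reduce


class ToyL1Executor:
    """Hypothetical executor returning one L1 value per example."""

    def __init__(self, global_batch_size: int) -> None:
        self._global_batch_size = global_batch_size

    @reduce_loss
    def call(self, fake: List[float], real: List[float]) -> List[float]:
        # Per-example absolute error (each "example" is a single number here).
        return [abs(f - r) for f, r in zip(fake, real)]


# One replica holding 2 of the 4 examples in the global batch.
executor = ToyL1Executor(global_batch_size=4)
replica_loss = executor.call([1.0, 2.0], [0.0, 4.0])
print(replica_loss)  # (1.0 + 2.0) / 4 = 0.75
```

Note that the replica divides by the global batch size (4), not by the 2 examples it actually holds; the per-replica values are meant to be summed across replicas afterwards.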

Using this approach, we were able to scale training to multiple GPUs and obtain the same numerical results as the equivalent single-GPU training (where the single-GPU batch size equals the sum of the per-GPU batch sizes).
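Why this matches single-GPU training: each replica divides the sum of its own per-example losses by the global batch size, and the per-replica gradients are then summed, which (since gradients are linear) is equivalent to summing the per-replica losses. A plain-Python check with made-up loss values and a hypothetical even split across two replicas:

```python
from typing import List


def per_replica_loss(per_example_loss: List[float], global_batch_size: int) -> float:
    # Same arithmetic as tf.nn.compute_average_loss without sample weights.
    return sum(per_example_loss) / global_batch_size


# Made-up per-example losses for a global batch of 8.
losses = [0.5, 1.5, 2.0, 0.25, 0.75, 1.0, 3.0, 1.0]
global_batch_size = len(losses)

# Single device: plain mean over the whole batch.
single_device = sum(losses) / global_batch_size

# Two replicas with 4 examples each; the replica results are summed.
replicas = [losses[:4], losses[4:]]
distributed = sum(per_replica_loss(r, global_batch_size) for r in replicas)

# Averaging over the local (per-replica) batch instead, then summing.
wrong = sum(sum(r) / len(r) for r in replicas)

print(single_device, distributed, wrong)  # 1.25 1.25 2.5
```

The last value shows what goes wrong if each replica averages over its local batch: with two equal replicas, the loss (and therefore the gradient) is twice the single-GPU value.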

Hope it helps

PS: you can find some good insights about adversarial training, GANs, distributed training, and distribution strategies in the ashpy project; feel free to refer to it or reuse parts of the code.


It works for me for a single loss term, for example when I only use L1 (or any other single loss) and reduce it with tf.nn.compute_average_loss. But numerical issues start to appear when I combine different loss functions, as is common in many image generation tasks.

My broader goal is to also keep the code as readable and as minimal as possible.


We apply the reduce_loss decorator individually to every loss function, and in our experiments (this was about two years ago, so TensorFlow's behavior may have changed since) there were no numerical stability problems. The combined loss is simply

    lambda1 * l1 + lambda2 * l2

where lambda1 and lambda2 are scalars and l1 and l2 are loss functions decorated with the reduce_loss decorator.
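Since every term is divided by the same global batch size, the scalar weights commute with the cross-replica sum, so a weighted combination of individually reduced losses still matches the single-device weighted mean. A plain-Python check; lambda1, lambda2, the two-replica split, and the per-example values are made up for illustration:

```python
from typing import List


def reduce_per_replica(per_example_loss: List[float], global_batch_size: int) -> float:
    # Sum of the replica's per-example losses divided by the global batch size.
    return sum(per_example_loss) / global_batch_size


lambda1, lambda2 = 0.5, 2.0
l1 = [1.0, 3.0, 2.0, 6.0]  # per-example values of the first loss (made up)
l2 = [0.25, 0.75, 0.5, 1.5]  # per-example values of the second loss (made up)
global_batch_size = len(l1)

# Single device: weighted sum of the two batch means.
single_device = lambda1 * sum(l1) / global_batch_size + lambda2 * sum(l2) / global_batch_size

# Two replicas, each reducing its own half of both losses; results are summed.
distributed = 0.0
for start in (0, 2):
    part1, part2 = l1[start:start + 2], l2[start:start + 2]
    distributed += (lambda1 * reduce_per_replica(part1, global_batch_size)
                    + lambda2 * reduce_per_replica(part2, global_batch_size))

print(single_device, distributed)  # 3.0 3.0
```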


Yes, that is what I had thought too. But from this guide:

If labels is multi-dimensional, then average the per_example_loss across the number of elements in each sample. For example, if the shape of predictions is (batch_size, H, W, n_classes) and labels is (batch_size, H, W) , you will need to update per_example_loss like: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32).
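The guide's recipe turns a per-element loss into a per-element mean: divide by the number of elements per sample (the product of the non-batch dimensions), then divide the sum by the global batch size. A plain-Python sketch with a made-up per-element loss of shape (batch=2, H=2, W=2), stored as nested lists and assuming a single replica:

```python
# Made-up per-element losses, shape (batch=2, H=2, W=2).
per_element = [
    [[1.0, 2.0], [3.0, 2.0]],
    [[0.5, 0.5], [1.0, 2.0]],
]
global_batch_size = len(per_element)  # single replica: global == local batch
# Product of shape[1:], i.e. H * W.
elements_per_sample = len(per_element[0]) * len(per_element[0][0])

flat = [v for sample in per_element for row in sample for v in row]

# Guide recipe: scale each element by 1 / (elements per sample),
# then sum everything and divide by the global batch size.
guide_value = sum(v / elements_per_sample for v in flat) / global_batch_size

# Reference: plain mean over every element in the batch.
plain_mean = sum(flat) / len(flat)

# Without the per-sample normalization: the mean per-sample *sum*.
unnormalized = sum(flat) / global_batch_size

print(guide_value, plain_mean, unnormalized)  # 1.5 1.5 6.0
```

Without the division by the number of elements, each term contributes its per-sample sum rather than its mean, so terms with different spatial sizes would effectively be weighted by their resolution on top of alpha and beta.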

In my case, the perceptual loss and the L1 loss both satisfy this criterion. So, this is what I did and things seem to be working fine:

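    import tensorflow as tf

    # BATCH_SIZE must be the global batch size (per-replica batch size x number of replicas).
    perceptual_loss = ...
    perceptual_loss /= tf.cast(tf.reduce_prod(tf.shape(activations)[1:]), tf.float32)
    perceptual_loss = tf.nn.compute_average_loss(perceptual_loss, global_batch_size=BATCH_SIZE)
    perceptual_loss = alpha * perceptual_loss

    # mean_absolute_error already averages over the channel axis and returns shape
    # (batch, H, W), so normalize by the element count of that tensor (H * W).
    reconstruction_loss = tf.keras.losses.mean_absolute_error(color_images, reconstructed_images)
    reconstruction_loss /= tf.cast(tf.reduce_prod(tf.shape(reconstruction_loss)[1:]), tf.float32)
    reconstruction_loss = tf.nn.compute_average_loss(reconstruction_loss, global_batch_size=BATCH_SIZE)
    reconstruction_loss = beta * reconstruction_loss

    # The weighted terms are summed (not multiplied) to form the total loss.
    total_loss = perceptual_loss + reconstruction_loss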
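One detail worth checking in the reconstruction term: tf.keras.losses.mean_absolute_error averages over the last (channel) axis, so the tensor it returns has shape (batch, H, W). Normalizing that tensor by the element count of color_images (H × W × C) rather than by its own element count (H × W) shrinks the term by an extra factor of C. The constant can be absorbed into beta, but the term then no longer equals a per-pixel mean. A plain-Python check with a made-up single image (H × W = 2 pixels, C = 2 channels, pixels flattened); mae_per_pixel is a hypothetical helper that mimics the channel averaging:

```python
from typing import List

Pixel = List[float]  # one pixel: C channel values


def mae_per_pixel(fake: List[Pixel], real: List[Pixel]) -> List[float]:
    # Mimics mean_absolute_error: mean of |fake - real| over the channel axis.
    return [sum(abs(f - r) for f, r in zip(fp, rp)) / len(fp) for fp, rp in zip(fake, real)]


# One made-up image (batch = 1) with 2 pixels and 2 channels.
real = [[0.0, 1.0], [2.0, 2.0]]
fake = [[1.0, 0.0], [2.0, 4.0]]
num_pixels, channels, global_batch_size = 2, 2, 1

per_pixel = mae_per_pixel(fake, real)  # channels already averaged once

# Normalizing by the image's element count (H * W * C) divides by C a second time.
by_image_size = sum(per_pixel) / (num_pixels * channels) / global_batch_size
# Normalizing by the loss tensor's own element count (H * W).
by_loss_size = sum(per_pixel) / num_pixels / global_batch_size

# Reference: plain mean of |fake - real| over all elements.
abs_errors = [abs(f - r) for fp, rp in zip(fake, real) for f, r in zip(fp, rp)]
true_mean = sum(abs_errors) / len(abs_errors)

print(by_image_size, by_loss_size, true_mean)  # 0.5 1.0 1.0
```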