I faced the same problem some years ago, and we came up with this solution:
From ashpy/executor.py:
@staticmethod
def reduce_loss(call_fn: Callable) -> Callable:
    """
    Create a decorator to reduce losses.

    Apply a ``reduce sum`` operation to the loss and divide the result
    by the batch size.

    Args:
        call_fn (:py:obj:`typing.Callable`): The executor call method.

    Returns:
        :py:obj:`typing.Callable`: The decorated function.

    """

    # decorator definition
    def _reduce(self, *args, **kwargs):
        return tf.nn.compute_average_loss(
            call_fn(self, *args, **kwargs),
            global_batch_size=self._global_batch_size,  # pylint: disable=protected-access
        )

    return _reduce
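If it is unclear what the decorator computes: tf.nn.compute_average_loss sums the per-example losses and divides the sum by the global batch size. Here is a minimal, self-contained sketch of the same pattern (the inlined reduce_loss and the ToyExecutor class are illustrative stand-ins, not ashpy code):

import tensorflow as tf

def reduce_loss(call_fn):
    # Same pattern as above, inlined so the demo runs standalone.
    def _reduce(self, *args, **kwargs):
        return tf.nn.compute_average_loss(
            call_fn(self, *args, **kwargs),
            global_batch_size=self._global_batch_size,
        )
    return _reduce

class ToyExecutor:
    """Illustrative executor whose call returns one loss value per example."""

    def __init__(self, global_batch_size):
        self._global_batch_size = global_batch_size

    @reduce_loss
    def call(self, per_example_losses):
        return per_example_losses

executor = ToyExecutor(global_batch_size=4)
print(executor.call(tf.constant([1.0, 2.0, 3.0, 4.0])))  # (1+2+3+4)/4 = 2.5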
We apply this decorator to every loss; in particular, for the L1 loss used during adversarial training we use it like this:
class GeneratorL1(GANExecutor):
    r"""
    L1 loss between the generator output and the target.

    .. math::
        L_G = \mathbb{E}\,\lVert x - G(z) \rVert_1

    where x is the target and G(z) is the generated image.
    """

    def __init__(self) -> None:
        """Initialize the Executor."""
        super().__init__(L1())

    @Executor.reduce_loss
    def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
        """
        Call the carried loss on `fake` and `real`.

        Args:
            context (:py:class:`ashpy.contexts.GANContext`): GAN Context.
            fake (:py:class:`tf.Tensor`): Fake data (generated).
            real (:py:class:`tf.Tensor`): Real data.

        Returns:
            :py:class:`tf.Tensor`: Output Tensor.
        """
        mae = self._fn(fake, real)
        return mae
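As a quick single-device sanity check, you can do something like this (a hypothetical snippet: in a real run the ashpy trainer injects _global_batch_size, here we set the private field by hand, and it assumes the wrapped L1 returns one value per example, which the decorator requires):

fake = tf.zeros((2, 8, 8, 3))
real = tf.ones((2, 8, 8, 3))
loss = GeneratorL1()
loss._global_batch_size = 2  # normally set by the trainer, not by hand
# Every absolute difference is 1, so we expect a reduced loss of ~1.0.
print(loss.call(None, fake=fake, real=real))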
Using this approach we were able to scale the training to multiple GPUs and obtain the same numerical results as training on a single GPU, provided the single-GPU batch size equals the global batch size (i.e. the sum of the per-replica batch sizes).
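This works because dividing by the global batch size is exactly what tf.distribute expects in a custom training loop: each replica contributes sum(per_example_loss) / global_batch_size, so the summed gradients match the single-GPU gradient. Here is a minimal sketch of such a loop, independent of ashpy (the model, optimizer, and data below are placeholders):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64  # sum of the per-replica batch sizes

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD()

@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            per_example_loss = tf.keras.losses.mse(y, model(x))
            # Divide by the GLOBAL batch size, exactly as reduce_loss does.
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE
            )
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
    # Summing the per-replica, already-scaled losses recovers the global mean.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((256, 8)), tf.random.normal((256, 1)))
).batch(GLOBAL_BATCH_SIZE)
for batch in strategy.experimental_distribute_dataset(dataset):
    print(train_step(batch).numpy())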
Hope it helps
PS: you can find some good insights about adversarial training, GANs, distributed training, and distribution strategies in the ashpy project; feel free to refer to it or reuse parts of the code.