For a single (default) strategy, gradient accumulation in a custom fit method can be achieved as described in this question: Gradient Accumulation with Custom model.fit in TF.Keras?
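For reference, the approach from that question boils down to something like the sketch below: gradients are accumulated into non-trainable variables and only applied every `n_gradients` batches. The class and attribute names (`CustomTrainStep`, `n_gradients`, `n_acum_step`) are illustrative, and it assumes a functional model so that the trainable variables already exist when the accumulator variables are created.

```python
import tensorflow as tf

class CustomTrainStep(tf.keras.Model):
    """Sketch: accumulate gradients for n_gradients batches, then apply once."""

    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # Assumes the variables already exist (e.g. functional inputs/outputs).
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v), trainable=False)
            for v in self.trainable_variables
        ]

    def train_step(self, data):
        self.n_acum_step.assign_add(1)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        # Accumulate this batch's gradients.
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
        # Apply (and reset) only once every n_gradients batches.
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients),
                self.apply_accu_gradients, lambda: None)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables))
        self.n_acum_step.assign(0)
        for acc in self.gradient_accumulation:
            acc.assign(tf.zeros_like(acc))
```

With the default (single-device) strategy this works inside a stock model.fit loop.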
But for multi-GPU (and TPU) cases there are complications; for example, it raises the following error:
`merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g., optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`.
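The setup that triggers it is essentially the same model built and compiled inside a strategy scope, roughly as in this repro sketch (the strategy, the toy architecture, and the dummy data are placeholders, not the actual training setup):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Flatten()(inputs)
    outputs = tf.keras.layers.Dense(10)(x)
    model = CustomTrainStep(n_gradients=4, inputs=inputs, outputs=outputs)
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

# Dummy data just to exercise the training loop.
x_train = tf.random.normal((256, 28, 28, 1))
y_train = tf.random.uniform((256,), maxval=10, dtype=tf.int32)

# Under MirroredStrategy this fit call fails with the merge_call error above:
# the tf.cond around optimizer.apply_gradients is a control flow statement
# containing a synchronization point inside the per-replica step.
model.fit(x_train, y_train, batch_size=64, epochs=1)
```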