For a single strategy, gradient accumulation in a custom fit method can be implemented as described in: Gradient Accumulation with Custom model.fit in TF.Keras?
But for multi-GPU (and TPU) cases there are complications; for example, it raises the following error.
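For reference, here is a minimal single-replica sketch of that approach (the `ACCUM_STEPS` constant, the toy model, and the data are illustrative choices of mine, not taken from the linked answer):

```python
import tensorflow as tf

# Toy setup: one accumulator variable per trainable weight.
ACCUM_STEPS = 4

model = tf.keras.Sequential([tf.keras.layers.Dense(1, kernel_initializer="ones")])
model.build((None, 3))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

accum = [tf.Variable(tf.zeros_like(v), trainable=False)
         for v in model.trainable_variables]

@tf.function
def accumulate(x, y):
    # Compute gradients for one micro-batch and add them to the accumulators.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    for a, g in zip(accum, grads):
        a.assign_add(g)
    return loss

@tf.function
def apply_accumulated():
    # Apply the averaged gradients once, then reset the accumulators.
    optimizer.apply_gradients(
        [(a / ACCUM_STEPS, v) for a, v in zip(accum, model.trainable_variables)])
    for a in accum:
        a.assign(tf.zeros_like(a))

# Usage: accumulate over ACCUM_STEPS micro-batches, then apply once.
x, y = tf.ones((8, 3)), tf.zeros((8, 1))
for _ in range(ACCUM_STEPS):
    accumulate(x, y)
apply_accumulated()
```

This works on a single device, but note that `apply_accumulated` is itself a `tf.function` containing `optimizer.apply_gradients`, which is exactly the kind of nesting that breaks under a distribution strategy, as the error below explains.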
```
RuntimeError: `merge_call` called while defining a new graph or a tf.function.
This can often happen if the function `fn` passed to `strategy.run()` contains
a nested `@tf.function`, and the nested `@tf.function` contains a
synchronization point, such as aggregating gradients (e.g.,
optimizer.apply_gradients), or if the function `fn` uses a control flow
statement which contains a synchronization point in the body. Such behaviors
are not yet supported. Instead, please avoid nested `tf.function`s or control
flow statements that may potentially cross a synchronization boundary, for
example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run`
inside a `tf.function` or move the control flow out of `fn`.
```
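The error message itself points at the expected shape of the fix: keep the synchronization point (`optimizer.apply_gradients`) out of any nested `tf.function`, and wrap the entire `strategy.run` call in a single outer `tf.function`. A minimal sketch of that pattern, without the accumulation logic (the model, `step_fn`, and `distributed_step` names are mine; on a machine without GPUs `MirroredStrategy` falls back to a single CPU replica, so it runs anywhere):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, kernel_initializer="ones")])
    model.build((None, 3))
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def step_fn(x, y):
    # Runs once per replica. Plain Python function: apply_gradients (the
    # synchronization point) is NOT hidden inside a nested tf.function.
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = tf.reduce_mean(tf.square(y - pred))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function  # the single outer tf.function wraps the entire strategy.run
def distributed_step(x, y):
    per_replica_loss = strategy.run(step_fn, args=(x, y))
    return strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

loss = distributed_step(tf.ones((8, 3)), tf.zeros((8, 1)))
```

How to fold gradient accumulation into `step_fn` without reintroducing a nested `tf.function` is the open part of the question.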