[Support] RuntimeError: merge_call called while defining a new graph or a tf.function

With a single-device strategy, gradient accumulation in a custom fit method can be implemented as described in: Gradient Accumulation with Custom model.fit in TF.Keras?

But for multi-GPU (and TPU) cases there are some complications. For example, the same approach raises the following error:

RuntimeError: merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn.
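For reference, here is a minimal sketch of what "move the control flow out of fn" could look like for gradient accumulation. It is not taken from the linked answer and makes several assumptions: a recent TensorFlow 2.x release, a toy linear model, and a manual SGD update with an explicit all-reduce standing in for optimizer.apply_gradients. The names ACCUM_STEPS, LEARNING_RATE, accumulate_step and apply_step are hypothetical. The idea is that the "apply or keep accumulating" decision is a plain Python `if` outside both fn and tf.function, so the synchronization point is never inside a conditional.

```python
import tensorflow as tf

ACCUM_STEPS = 4      # hypothetical: micro-batches per weight update
LEARNING_RATE = 0.1  # hypothetical

# MirroredStrategy falls back to a single CPU replica when no GPU is visible.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Toy linear model weight (assumption, stands in for model.trainable_variables).
    w = tf.Variable(tf.zeros([3, 1]))
    # Per-replica accumulator: each replica adds to its own local copy.
    accum = tf.Variable(
        tf.zeros([3, 1]),
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM,
    )

def _accumulate(x, y):
    # Per-replica work only, no cross-replica synchronization.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    accum.assign_add(tape.gradient(loss, w))
    return loss

def _apply():
    # The synchronization point sits at the top level of fn, not inside a branch.
    ctx = tf.distribute.get_replica_context()
    total = ctx.all_reduce(tf.distribute.ReduceOp.SUM, accum.read_value())
    scale = float(ACCUM_STEPS * strategy.num_replicas_in_sync)
    w.assign_sub(LEARNING_RATE * total / scale)
    accum.assign(tf.zeros(accum.shape))  # reset the local accumulator

@tf.function
def accumulate_step(x, y):
    return strategy.run(_accumulate, args=(x, y))

@tf.function
def apply_step():
    strategy.run(_apply)

# Toy data; every replica sees the same batch in this sketch.
x = tf.ones([8, 3])
y = tf.ones([8, 1])

for step in range(1, 2 * ACCUM_STEPS + 1):
    accumulate_step(x, y)
    if step % ACCUM_STEPS == 0:  # Python-level branch, outside fn and tf.function
        apply_step()
```

In a real training loop the batches would come from strategy.experimental_distribute_dataset, and the manual update would be replaced by the optimizer call. The split into two tf.functions is only one way to keep the conditional away from the synchronization point.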