Confusion regarding implementation of `mirrored_run`

Hi, I'm currently trying to understand the internals of `MirroredStrategy`, and in particular how gradients of `MirroredVariable`s are recorded.

I understand the concept of a `MirroredVariable`, but it's unclear to me how a correct gradient tape is recorded over these variables in `_call_for_each_replica` in `mirrored_strategy`. Since most of that implementation seems to live in `mirrored_run`, I focused mainly on that file instead. Say we have one `MirroredVariable` with the following components:

```
MirroredVariable {
  0: <tf.Variable 'w:0' shape=() dtype=float32>,
  1: <tf.Variable 'w/replica_1:0' shape=() dtype=float32>
}
```
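For anyone who wants to reproduce this without GPUs, here is a minimal sketch, assuming TensorFlow 2.x and only its public `tf.distribute` API, that splits the CPU into two logical devices (standing in for two GPUs) and creates a variable of this shape. `strategy.experimental_local_results` returns the per-replica component variables:

```python
import tensorflow as tf

# Assumption: no GPUs available, so split the physical CPU into two logical
# devices. This has to run before the TensorFlow runtime is initialized.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)

strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])

with strategy.scope():
    # Created under the scope, `w` becomes a MirroredVariable holding one
    # component variable per replica device.
    w = tf.Variable(1.0, name="w")

components = strategy.experimental_local_results(w)
names = [v.name for v in components]
for v in components:
    print(v.name, v.device)
```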

To understand this behavior, I altered the `_call_for_each_replica` implementation so that it runs the function for each replica sequentially on that replica's device (i.e. simply removing the replica threads). This works for variable creation, computation and reduction, but breaks when recording gradients. Say I have the following function:

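```python
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function

# `strategy`, `w` (the MirroredVariable above) and `optimizer` are created beforehand.

@def_function.function
def step(x):
    with backprop.GradientTape() as tape:
        loss = w * x

    optimizer.minimize(loss, var_list=[w], tape=tape)

strategy.run(step, args=(2.0,))
```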

This yields:

```
No gradients provided for any variable: ['w:0']
```

Adding `tape.watch(w)` doesn't change anything. My guess is that this is due to the function wrapping that happens in `call_for_each_replica` in `mirrored_run`. Could anyone shed some light on how these gradients are recorded here?
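For comparison, here is a minimal sketch of the same step with the unmodified (threaded) strategy, assuming TensorFlow 2.x's public API and two logical CPU devices standing in for two GPUs. It returns the gradient from `tape.gradient` instead of calling `optimizer.minimize`, so the per-replica values can be inspected directly:

```python
import tensorflow as tf

# Assumption: two logical CPU devices stand in for two GPUs
# (must be configured before the TensorFlow runtime is initialized).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)
strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])

with strategy.scope():
    w = tf.Variable(1.0, name="w")  # becomes a MirroredVariable

@tf.function
def step(x):
    with tf.GradientTape() as tape:
        loss = w * x
    # Return the gradient instead of applying it so it can be inspected.
    return tape.gradient(loss, w)

per_replica_grads = strategy.run(step, args=(2.0,))
grads = [float(g) for g in strategy.experimental_local_results(per_replica_grads)]
print(grads)  # [2.0, 2.0]: d(w * x)/dw = x on each replica
```

With the stock implementation each replica gets a non-`None` gradient, so diffing this against the sequential variant may help narrow down which part of `_call_for_each_replica` the tape relies on.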