Warning says I am not using `tf.function` when calling `tf.distribute.Strategy.run`, but I am

Hi TF developer community.

I'm a maintainer of RLlib, a library for distributed reinforcement learning. I'm writing a new distributed training stack using TensorFlow 2.11, and I have a question about a warning I see when using `tf.function` with `tf.distribute.Strategy.run`.

I frequently see the following warning:

```
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead
currently. We will be working on improving this in the future, but for now
please wrap `call_for_each_replica` or `experimental_run` or `run` inside a
tf.function to get the best performance.
```

The following code roughly describes my setup (it is simplified, so it is not an exact reproduction):

```python
from typing import Any, Mapping

import tensorflow as tf


def _do_update_fn(self, batch) -> Mapping[str, Any]:
    def helper(_batch):
        with tf.GradientTape() as tape:
            # Underlying forward call to multiple Keras models
            fwd_out = self.keras_model_container.forward_train(_batch)
            # Loss computation
            loss = self.compute_loss(fwd_out=fwd_out, batch=_batch)
        gradients = self.compute_gradients(loss, tape)
        self.apply_gradients(gradients)
        return {
            "loss": loss,
            "fwd_out": fwd_out,
            "postprocessed_gradients": gradients,
        }

    # self.strategy is a tf.distribute.Strategy object
    return self.strategy.run(helper, args=(batch,))


update_fn = tf.function(self._do_update_fn, reduce_retracing=True)

batch = ...
update_fn(batch)
```
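A stripped-down, self-contained version of the same pattern can help check whether `strategy.run` is actually reached while tracing. This is only a sketch under assumptions: the Keras model container, loss, and gradient helpers are replaced by a hypothetical linear model stored in plain `tf.Variable`s with a manual SGD step. `eager_flags` is a hypothetical debugging list that records `tf.executing_eagerly()` just before `strategy.run` is called.

```python
import tensorflow as tf

# Hypothetical stand-ins for the learner pieces above: a linear model held in
# plain tf.Variables, an MSE loss, and a manual SGD step (not RLlib code).
strategy = tf.distribute.MirroredStrategy()  # all visible GPUs, else the CPU

with strategy.scope():
    # MEAN aggregation allows assign_sub to be called in replica context.
    w = tf.Variable(tf.zeros((4, 1)), aggregation=tf.VariableAggregation.MEAN)
    b = tf.Variable(tf.zeros((1,)), aggregation=tf.VariableAggregation.MEAN)

# Hypothetical debugging record. The append is a Python side effect, so it
# only runs when the Python function itself executes (i.e. while tracing).
eager_flags = []


def do_update(batch):
    def helper(_batch):
        with tf.GradientTape() as tape:
            pred = tf.matmul(_batch["obs"], w) + b
            loss = tf.reduce_mean(tf.square(pred - _batch["target"]))
        grad_w, grad_b = tape.gradient(loss, [w, b])
        w.assign_sub(0.1 * grad_w)
        b.assign_sub(0.1 * grad_b)
        return loss

    # False means this strategy.run call is being traced into a graph.
    eager_flags.append(tf.executing_eagerly())
    return strategy.run(helper, args=(batch,))


update_fn = tf.function(do_update, reduce_retracing=True)

batch = {
    "obs": tf.random.uniform((8, 4)),
    "target": tf.random.uniform((8, 1)),
}
loss = update_fn(batch)
print(eager_flags)
```

A `False` entry means that `strategy.run` call was built into a graph. A `True` entry would mean it ran eagerly; for example, `tf.config.run_functions_eagerly(True)` makes every `tf.function` execute eagerly, and calling the undecorated method directly does the same.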

Can anyone explain why I might be getting this warning even though I am wrapping the update in `tf.function`?

Thanks!