MirroredStrategy with jit_compile=True

Hello,

I found out that jit_compile=True makes the model run much faster. At the same time, MirroredStrategy is a good way to process the data faster. When I try to combine the two approaches with 2 or more GPUs, it does not work. Why is this the case? Is it a bug? Can it be avoided?

I get this error message: “UnimplementedError: We failed to lift variable creations out of this tf.function, so this tf.function cannot be run on XLA. A possible workaround is to move variable creation outside of the XLA compiled function.”

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = …  # model definition omitted

model.compile(loss=...,
              optimizer=...,
              jit_compile=True)  # compile the train step with XLA

model.fit(train_dataset, epochs=10)

I am using TensorFlow 2.15 with Keras 3.0.4.

Hi @Alexandre_Moritz

Welcome to the TensorFlow Forum!

Could you please share a minimal reproducible example so we can replicate the error and understand the issue better?

tf.function has limitations on creating variables inside the function body, and with jit_compile=True, XLA cannot compile a function whose variable creations cannot be lifted out of it. It is recommended to create variables outside the tf.function, under strategy.scope(), and pass them as arguments to the function; this allows XLA to compile the function.
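As a minimal sketch of that advice, the variable below is created eagerly under strategy.scope() and then passed into an XLA-compiled tf.function that only reads it and creates no variables of its own. The names `w` and `scale` and the values are hypothetical illustrations, not taken from the original post.

```python
import tensorflow as tf

# With no GPUs visible, MirroredStrategy falls back to a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

# Hypothetical variable: created eagerly, outside any tf.function,
# under the strategy scope so that it is mirrored across replicas.
with strategy.scope():
    w = tf.Variable(2.0)

@tf.function(jit_compile=True)
def scale(v, x):
    # The XLA-compiled function only reads the variable passed in;
    # it does not create any variables itself.
    return v * x

y = scale(w, tf.constant(3.0))
print(float(y))
```

This sketch only illustrates where the variable is created relative to the compiled function; it is not confirmed that the same pattern resolves the error in the multi-GPU Keras `model.fit` setup from the question.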

Alternatively, you can try other distribution strategies, such as TPUStrategy or MultiWorkerMirroredStrategy, which might be more compatible with XLA. Thank you.