Advice on speeding up very slow XLA compile times with `jit_compile=True`

I have a proprietary model that I can’t share source code for, written in TF and Keras.

When I pass jit_compile=True to model.compile() I get all the expected speed-ups, significantly better GPU utilisation, etc. However, the initial compilation phase is unbelievably long: up to an hour before the first training step.

Is there some simple trick I'm missing to speed up compilation? For example, if the same function is called multiple times, should I be decorating it with tf.function? When I try that, it seems to cause errors about compiling a function that is already compiled.
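For context, the kind of pattern I mean is something like this (a toy stand-in, not my actual model; `train_step` and the computation inside it are made up for illustration):

```python
import tensorflow as tf

# Hypothetical stand-in for a real training step. jit_compile=True asks XLA
# to compile the function the first time it is traced with a given input
# signature; subsequent calls with the same signature reuse the compiled
# computation instead of recompiling.
@tf.function(jit_compile=True)
def train_step(x):
    return tf.reduce_sum(x * x)

result = train_step(tf.constant([1.0, 2.0, 3.0]))
print(float(result))  # 14.0
```

My question is essentially whether wrapping repeated work like this helps when the model is already compiled via model.compile(jit_compile=True), or whether the two mechanisms conflict.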

Any tips?

Apologies for the vagueness, but I do feel someone might have wisdom to share on this even without a repro.

Kind Regards,