Limiting recompilation of jax2tf model

I trained a model with JAX, exported it with jax2tf.convert, saved it, and now load it with tf.saved_model.load. Inference works correctly, but I periodically see the four log lines below, coinciding with a pause in inference and a spike in GPU usage. I only ever call model.__call__() in a tight loop with a relatively constant input shape. I process multiple files, and each file generates on the order of 10k batches, so the last batch for a given file is often smaller than the typical batch size, and there are sometimes pauses while data is loaded from disk.

When doing inference at peak throughput my GPU usage hovers around 70% utilization, but when the pauses happen the utilization spikes to 100% for a few seconds.

Am I correct in assuming the pauses + messages + GPU spikes indicate model recompilation? If so, is there anything I can do to avoid this?

I0000 00:00:1710940593.002953 3467316 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_930', 40 bytes spill stores, 40 bytes spill loads
I0000 00:00:1710940593.070684 3467310 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_930', 4 bytes spill stores, 4 bytes spill loads
I0000 00:00:1710940593.215734 3467308 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_930', 276 bytes spill stores, 276 bytes spill loads
I0000 00:00:1710940593.460702 3467326 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_930', 180 bytes spill stores, 180 bytes spill loads

TensorFlow Solutions Architect

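Yes, the pauses and GPU spikes likely indicate model recompilation, most often triggered by a change in input shape such as a smaller final batch. The ptxas warnings are printed while GPU kernels are being compiled, which is consistent with that. To minimize recompilations, keep the batch size constant (for example, pad or drop the short final batch of each file), compile with static input shapes, and run a warm-up phase so compilation happens before the main inference loop. Keeping shapes fixed in your input pipeline, rather than handling them per call, can also help.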
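As a rough illustration of the padding and warm-up suggestions, here is a minimal sketch. It assumes the model treats each row of a batch independently, so zero-padded rows can simply be discarded afterwards. BATCH_SIZE, pad_batch, run_fixed_shape and fake_model are hypothetical names for this example; fake_model stands in for the loaded SavedModel and is not part of jax2tf or TensorFlow.

```python
import numpy as np

BATCH_SIZE = 8  # assumed fixed batch size used for every call


def pad_batch(batch, batch_size=BATCH_SIZE):
    """Zero-pad `batch` along axis 0 up to `batch_size`.

    Returns the padded array and the number of real (unpadded) rows.
    """
    n_valid = batch.shape[0]
    if n_valid > batch_size:
        raise ValueError("batch is larger than the fixed batch size")
    pad_rows = batch_size - n_valid
    pad_width = [(0, pad_rows)] + [(0, 0)] * (batch.ndim - 1)
    return np.pad(batch, pad_width), n_valid


def run_fixed_shape(model_fn, batch, batch_size=BATCH_SIZE):
    """Call `model_fn` with a constant batch shape, then drop the padded rows."""
    padded, n_valid = pad_batch(batch, batch_size)
    return model_fn(padded)[:n_valid]


def fake_model(x):
    # Hypothetical stand-in for the loaded model: one output per input row.
    return x.sum(axis=1)


# Warm-up: one call with the fixed shape before the main loop, so that any
# shape-specific compilation would happen up front.
_ = fake_model(np.zeros((BATCH_SIZE, 3), dtype=np.float32))

# A short final batch (5 rows) is padded to 8 rows before the call.
last_batch = np.ones((5, 3), dtype=np.float32)
padded, n_valid = pad_batch(last_batch)
out = run_fixed_shape(fake_model, last_batch)
```

With this approach the model only ever sees the shape (BATCH_SIZE, 3), at the cost of some wasted compute on the final batch of each file.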

It seems like tf.data.experimental.pad_to_cardinality (TensorFlow v2.15.0.post1) is relevant functionality here…