Model consuming RaggedTensors fails during evaluation in a distributed setting

We have a model that consumes multiple ragged tensors per batch. It runs fine on a single GPU, but as soon as we introduce distributed training, its evaluation fails.

Note that training in the distributed setting proceeds smoothly; the failure occurs only during evaluation. Since we cannot share the original data and model, we are providing a minimal snippet in the following notebook that reproduces the issue. You can reproduce it either in this Colab or on a multi-GPU machine; we have verified both, and the issue persists in each.

More details are available here:

Cc: @nilabhra