Vanilla Transformer error on training with TPU

I am trying to train a vanilla Transformer following Vaswani et al.
I got this error when training on a TPU on Colab. The same model trained fine on CPU and GPU with jit_compile=True, but it fails on TPU. I suspect that some tensor op that is not TPU-compatible is causing the problem, but I am unable to pinpoint it.

<ipython-input-15-f20674498ec4> in <cell line: 1>()
----> 1,validation_data=valid_ds,epochs=10,steps_per_epoch=train_steps,validation_steps=valid_steps)

1 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/ in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/core/function/capture/ in capture_by_value(self, graph, tensor, name)
    120         graph_const = self.by_val_internal.get(id(tensor))
    121         if graph_const is None:
--> 122           graph_const = tensor._capture_as_const(name)  # pylint: disable=protected-access
    123           if graph_const is None:
    124             # Some eager tensors, e.g. parallel tensors, are not convertible to

InternalError: failed to connect to all addresses; last error: UNKNOWN: ipv4: Failed to connect to remote host: Connection refused
Additional GRPC error information from remote target /job:localhost/replica:0/task:0/device:CPU:0:
:UNKNOWN:failed to connect to all addresses; last error: UNKNOWN: ipv4: Failed to connect to remote host: Connection refused {created_time:"2023-09-17T14:29:09.344781733+00:00", grpc_status:14}
Executing non-communication op <MultiDeviceIteratorInit> originally returned UnavailableError, and was replaced by InternalError to avoid invoking TF network error handling logic.

My code is really clean. I would really appreciate it if you took the time to look through it.
It is a Spanish-to-English translation dataset. As it is really small (2.5 MB), I thought I could cache the whole thing and train on TPU.
I followed Google's Transformers_tutorial.ipynb and coded this.
Thank you very much in advance.
As I am still working on that code, I will reply right away to any of your suggestions.

In the past, Keras released CPU & GPU versions first, and then proofed out any quirks on TPUs. How old is your Keras/TF version, and is it exactly the same for CPU/GPU and TPU?
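
One way to answer the version question is to run the same small check in each Colab runtime (CPU/GPU and TPU) and compare the output. This is only a sketch: the helper name `installed_versions` is made up for illustration, and it assumes the packages are installed under the distribution names `tensorflow` and `keras`.

```python
# Sketch: report installed package versions so the CPU/GPU and TPU
# runtimes can be compared side by side.
# Assumption: the distributions are named "tensorflow" and "keras".
from importlib import metadata


def installed_versions(packages=("tensorflow", "keras")):
    """Return {package name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is missing in this runtime.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If the two runtimes print different versions, that difference is worth ruling out before looking for a TPU-incompatible op.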

I have not tracked this for a while, so I do not know if this suggestion will be helpful.

Also, you may have found the one place where cross-TPU communication has to be complex. I believe there is a TPU version of BatchNormalization because of this problem.