Retracing with Distributed Training

Hi guys,

I am training a custom model with tf.function() and distributed training across multiple GPUs. The function retraces (recompiles the graph) on every call. To avoid the retracing, I passed the input_signature argument with a specified tf.TensorSpec() to tf.function(), which works fine on one GPU. However, when I use multiple GPUs, it fails with the error 'PerReplica does not have dtype'.
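Here is a minimal sketch of the kind of setup I mean (the model, shapes, and loss below are placeholders, not my actual code; it assumes tf.distribute.MirroredStrategy):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Placeholder model; variables are created inside the strategy scope.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
    optimizer = tf.keras.optimizers.SGD()

# Specifying input_signature stops the retracing on a single GPU...
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 10], dtype=tf.float32),
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
])
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# ...but strategy.run passes PerReplica values into train_step, which
# do not match the TensorSpec signature and trigger the dtype error.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 10]), tf.random.normal([64, 1]))).batch(8)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for x, y in dist_dataset:
    strategy.run(train_step, args=(x, y))
```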

Any idea how I can solve this problem?