tf.nn.ctc_loss error on TPU with variable-length audio data and padded_batch

I am following the guide here https://keras.io/examples/audio/ctc_asr/

Everything works on GPU, but training is painfully slow (5-6 hours per epoch). On TPU, however, I run into a problem with the CTC loss:

2023-09-24 21:50:54.866041: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] Input 0 to node `gradients/ctc_loss_dense/concat_grad/ConcatOffset` with op ConcatOffset must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

	 [[{{node gradients/ctc_loss_dense/concat_grad/ConcatOffset}}]]

My data pipeline:

# use a 15% random sample as the training set
df_train=(
    grouped_df
    .get_group("train")
    .sample(
        int(len(grouped_df.get_group("train").index) * 0.15)
    )
)
df_val=(
    grouped_df
    .get_group("valid")
    .sample(
        int(len(grouped_df.get_group("valid").index) * 0.10)
    )
)

# Define the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_train["file_path"]), list(df_train["sentence"]))
)
train_dataset = (
    train_dataset
    .map(decode_and_process_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(ConfigData.BATCH_SIZE, drop_remainder=True)
    .cache(filename="/tmp/cache.train")  # cache before prefetch; prefetch should be last
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Define the validation dataset
validation_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_val["file_path"]), list(df_val["sentence"]))
)
validation_dataset = (
    validation_dataset
    .map(decode_and_process_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(ConfigData.BATCH_SIZE, drop_remainder=True)
    .cache(filename="/tmp/cache.valid")  # cache before prefetch; prefetch should be last
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
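For context on the TPU error: `padded_batch` without an explicit `padded_shapes` pads each batch only to that batch's longest element, so the time dimension differs from batch to batch, i.e. it is not a compile-time constant for XLA. A toy sketch (the lengths and feature setup here are made up, not my real data):

```python
import tensorflow as tf

# Variable-length 1-D "features", standing in for decode_and_process_sample output.
lengths = [3, 5, 2, 8]
ds = tf.data.Dataset.from_generator(
    lambda: (tf.zeros([n]) for n in lengths),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32),
)

# Without padded_shapes, each batch is padded only to its own longest element.
for batch in ds.padded_batch(2):
    print(batch.shape)  # (2, 5) then (2, 8) -- the shape varies per batch
```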

My current CTC loss function:

@tf.function(jit_compile=True)
def CTCLoss2(y_true, y_pred):

    batch_len = tf.shape(y_true)[0]
    logit_length = tf.fill([batch_len], tf.shape(y_pred)[1])
    label_length = tf.fill([batch_len], tf.shape(y_true)[1])

    return tf.math.reduce_mean(
        tf.nn.ctc_loss(
            logits=y_pred,
            labels=y_true,
            label_length=label_length,
            logit_length=logit_length,
            blank_index=-1,
            logits_time_major=False,
        )
    )
    

What am I doing wrong? How do I solve this? :frowning_face:

Hi @anan5a, generally XLA requires certain values to be known at compile time, such as the reduction axis of a reduce operation or transposition dimensions. Could you please try passing default values for those variables? Thank you.
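If "default values" means fixing the shapes up front, one possibility (my assumption, not an official fix) is to give `padded_batch` explicit `padded_shapes`, so every batch has the same compile-time-known shape. All names and sizes below are hypothetical placeholders:

```python
import tensorflow as tf

MAX_FRAMES = 16  # hypothetical fixed maximum number of time steps
NUM_FEATS = 4    # hypothetical feature dimension
MAX_LABEL = 6    # hypothetical maximum label length

# (time_steps, label_len) pairs standing in for variable-length samples.
lengths = [(3, 2), (10, 5), (7, 4), (16, 6)]
ds = tf.data.Dataset.from_generator(
    lambda: ((tf.zeros([t, NUM_FEATS]), tf.zeros([l], tf.int32)) for t, l in lengths),
    output_signature=(
        tf.TensorSpec([None, NUM_FEATS], tf.float32),
        tf.TensorSpec([None], tf.int32),
    ),
)

# Fixed padded_shapes + drop_remainder=True => every dimension is static.
ds = ds.padded_batch(
    2,
    padded_shapes=([MAX_FRAMES, NUM_FEATS], [MAX_LABEL]),
    drop_remainder=True,
)
for feats, labels in ds:
    print(feats.shape, labels.shape)  # always (2, 16, 4) (2, 6)
```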