Issue with creating a custom loss function

Hi. I’m trying to fine-tune GPT-2 for text summarization, and my training examples look like this:

Text: {text}\nSummary: {summary}

So I wanted to create a custom loss function that only counts the tokens that come after the word ‘Summary’.
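For reference, the masking idea can be sketched in plain Python for one sequence of label IDs. It mirrors the `tf.math.cumsum` trick in the loss below: keep a running count of how often the marker has appeared, then turn that count into a 0/1 weight. `summary_mask` and the token IDs in the example are hypothetical; real GPT-2 IDs would come from the tokenizer.

```python
from itertools import accumulate


def summary_mask(label_ids, summary_token_id):
    """Return 1.0 at every position at or after the first marker, else 0.0.

    Same logic as: cumsum(label == summary_token_id) > 0
    """
    hits = [1 if t == summary_token_id else 0 for t in label_ids]
    running = accumulate(hits)  # running count of marker occurrences
    return [1.0 if count > 0 else 0.0 for count in running]


# Hypothetical IDs: 42 stands in for the "Summary" token.
print(summary_mask([5, 9, 42, 7, 8], 42))  # [0.0, 0.0, 1.0, 1.0, 1.0]
print(summary_mask([5, 9, 7, 8], 42))      # [0.0, 0.0, 0.0, 0.0]
```

Two things this makes visible: the marker position itself gets weight 1, and if the ID passed as `summary_token_id` never appears in the labels, every weight is 0 and that sequence contributes nothing to the loss. With GPT-2's byte-level BPE, "Summary" and " Summary" (with a leading space) are different tokens, so the ID has to match how the word is actually tokenized after "\n".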

Here is my current code.

Custom loss function:

import tensorflow as tf

def custom_sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, summary_token_id=None):
    scc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)
    # Running count of 'Summary' tokens along each sequence (last axis).
    mask = tf.math.cumsum(tf.cast(tf.equal(y_true, summary_token_id), dtype=tf.float32), axis=-1)
    sample_weight = tf.cast(mask > 0, dtype=tf.float32)  # weight 0 before the first 'Summary' token, 1 from it onward
    return scc(y_true, y_pred, sample_weight=sample_weight)
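A rough pure-Python model of what a weighted loss like this reports. As far as I know, tf.keras losses with the default `SUM_OVER_BATCH_SIZE` reduction divide the weighted sum by the total number of positions, not by the sum of the weights, so zeroing out most positions also shrinks the reported number. The sketch compares that with averaging over the kept positions only. `token_nll`, `weighted_losses`, and the numbers are illustrative, not Keras APIs.

```python
import math


def token_nll(probs, label):
    """Negative log-likelihood of the correct label under one probability row."""
    return -math.log(probs[label])


def weighted_losses(prob_rows, labels, weights):
    """Return (mean over all positions, mean over weighted positions only)."""
    per_token = [w * token_nll(p, y) for p, y, w in zip(prob_rows, labels, weights)]
    total = sum(per_token)
    over_all = total / len(per_token)  # divides by every position, masked or not
    kept = sum(weights)
    over_kept = total / kept if kept > 0 else 0.0  # divides by summary positions only
    return over_all, over_kept


# Four positions, each correct label predicted with probability 0.5,
# but only the last position is weighted.
rows = [[0.5, 0.5]] * 4
labels = [0, 1, 0, 1]
weights = [0.0, 0.0, 0.0, 1.0]
print(weighted_losses(rows, labels, weights))  # roughly (0.173, 0.693)
```

The per-token quality is identical in both numbers; only the divisor changes. In TensorFlow, the second behavior would roughly correspond to computing the unreduced per-token loss, multiplying by the mask, and dividing the sum by `tf.reduce_sum(sample_weight)` instead of relying on the default reduction.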

Compiling the model and training

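import keras_nlp
from tensorflow import keras

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",  # other arguments (e.g. sequence_length) not shown
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,  # initial learning rate: placeholder, actual value not shown
    decay_steps=train_ds.cardinality() * num_epochs,
)

# summary_token_id: token ID of 'Summary' (its definition is not shown here)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),  # optimizer assumed, not shown
    loss=lambda y_true, y_pred: custom_sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True, summary_token_id=summary_token_id
    ),
    weighted_metrics=["accuracy"],  # metric line assumed, not shown
)
gpt2_lm.fit(train_ds, epochs=num_epochs)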

But the training results are very strange. I got loss: 0.0585 - accuracy: 0.5588, and the accuracy barely moved during training.

When I trained with just the default loss function (keras.losses.SparseCategoricalCrossentropy(from_logits=True)), it worked pretty well, reaching a loss of 1.3476 and an accuracy of 0.6380 after one epoch.
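One way to see why the two loss values may not be on the same scale: if the weighted loss is averaged over all N token positions (which, as far as I can tell, is the tf.keras default reduction), then with weights w_i in {0, 1}, per-token loss ℓ_i, and K = Σ w_i kept positions, the reported value factors as follows.

```latex
L_{\text{reported}}
  = \frac{1}{N}\sum_{i=1}^{N} w_i \ell_i
  = \frac{K}{N}\cdot\underbrace{\frac{1}{K}\sum_{i:\,w_i = 1} \ell_i}_{\text{mean loss on summary tokens}}
```

If summary tokens are a small fraction K/N of all positions (padding counts toward N too), a reported 0.0585 can hide a much larger per-summary-token loss. Separately, the custom mask only exists inside the loss function; an accuracy metric never sees it, so accuracy is still scored on the "Text:" part as well.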

Can anyone help me figure out what I’ve done wrong? Thanks