I have been training a decoder-based transformer for word generation, but it keeps generating the same words over and over again.

I have been trying to build a decoder-based transformer for text generation, and the text it generates is the same no matter what the input sequence is.

The following is my code; some of the preprocessing code has been removed.

def process_batch(ds):
    """Turn one batch of raw text into (features, labels) for next-token training.

    The batch is tokenized, each row is wrapped as [START] ... [END] and
    padded/truncated with [PAD] to MAX_SEQUENCE_LENGTH + 1 tokens, and the
    result is split into a one-step-shifted input/target pair.

    Args:
        ds: one batch of raw examples as produced by `Dataset.batch`;
            presumably strings the module-level `tokenizer` accepts — confirm.

    Returns:
        A tuple `({"decoder_inputs": inputs}, targets)` where `inputs` is every
        packed row minus its last token and `targets` is every packed row minus
        its first token, so position i of `targets` is the token that follows
        position i of `inputs`.
    """
    token_ids = tokenizer(ds)

    start_id = tokenizer.token_to_id("[START]")
    end_id = tokenizer.token_to_id("[END]")
    # NOTE(review): the model's TokenAndPositionEmbedding uses mask_zero=True,
    # i.e. it treats token id 0 as padding. That only matches this pad value
    # if "[PAD]" maps to id 0 in the vocabulary — confirm.
    pad_id = tokenizer.token_to_id("[PAD]")

    # One extra slot so that, after the shift below, both inputs and targets
    # are exactly MAX_SEQUENCE_LENGTH long (the embedding's sequence_length).
    packer = StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=start_id,
        end_value=end_id,
        pad_value=pad_id,
    )
    packed = packer(token_ids)

    features = {"decoder_inputs": packed[:, :-1]}
    labels = packed[:, 1:]
    return (features, labels)


def make_ds(seq):
    """Build a tokenized, batched tf.data pipeline from an in-memory sequence.

    Pipeline order: slice -> batch -> tokenize/pack (`process_batch`) ->
    cache -> shuffle -> prefetch.

    Args:
        seq: anything accepted by `tf.data.Dataset.from_tensor_slices`
            (presumably a list/array of raw text examples — confirm).

    Returns:
        A `tf.data.Dataset` yielding `({"decoder_inputs": ...}, targets)`
        batches as produced by `process_batch`.
    """
    dataset = tf.data.Dataset.from_tensor_slices(seq)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.map(process_batch, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache the expensive tokenized batches BEFORE shuffling. `cache()` replays
    # exactly the elements it saw on the first pass, so shuffling upstream of
    # it (as before) froze one random order for every subsequent epoch.
    dataset = dataset.cache()
    # Re-shuffled every epoch. Note this shuffles whole batches, not the
    # individual examples inside a batch.
    dataset = dataset.shuffle(128)
    # Prefetch last so the whole pipeline overlaps with training; placed
    # before cache() it stopped having any effect once the cache was warm.
    return dataset.prefetch(tf.data.AUTOTUNE)

# Training and validation pipelines go through identical preprocessing
# (make_ds), including its batch shuffling.
train_ds, val_ds = make_ds(train_seq), make_ds(val_seq)

This is the decoder section. I was using keras_nlp, and the model has 2 decoder layers.

# Decoder-only language model:
#   token + position embeddings -> 2 x TransformerDecoder -> dropout
#   -> per-position softmax over the vocabulary.
decoder_inputs = Input(shape=(None,), dtype="int64", name="decoder_inputs")

# mask_zero=True makes token id 0 act as padding. NOTE(review): this assumes
# tokenizer.token_to_id("[PAD]") == 0 in the data pipeline — confirm.
x = TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(decoder_inputs)

# Two stacked decoder blocks with no encoder input (decoder-only). Presumably
# relies on TransformerDecoder's default causal self-attention mask — confirm
# for the installed keras_nlp version.
for _ in range(2):
    x = TransformerDecoder(
        intermediate_dim=INTERMEDIATE_DIM,
        num_heads=NUM_HEADS,
    )(x)

x = Dropout(0.5)(x)

# Probabilities (softmax), matching the non-logit sparse_categorical_crossentropy
# loss used in compile() below.
decoder_ouput = Dense(VOCAB_SIZE, activation="softmax")(x)
decoder = Model([decoder_inputs], decoder_ouput)

# NOTE(review): `transformer` wraps `decoder` as a single nested layer applied
# to the same Input, which is redundant. It is left as-is because weights
# saved from `transformer` (see the load_weights line below) were presumably
# saved with this nesting.
decoder_outputs = decoder([decoder_inputs])
transformer = Model(inputs=decoder_inputs, outputs=decoder_outputs, name="transformer")

# Uncomment to resume from a saved checkpoint:
#transformer.load_weights("/content/my-drive/MyDrive/projects/Olsen/weights-improvement-07-0.41.hdf5")
transformer.compile(
    "adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)