NaN loss occurring when training a transformer model for machine translation

I am trying to train my model. It builds without any issues, but during training the loss becomes NaN and the gradients do not seem to be computing properly. I have tried gradient clipping and switching optimizers, and I have also filtered my data to make sure it contains no NaN values, but none of that worked. I would really appreciate help figuring this out.
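For context, this is roughly what I already tried: checking the padded index arrays for NaN values and adding gradient clipping to the optimizer. This is only a simplified sketch; check_for_nan is just an illustrative helper, and the array names refer to the arrays built in the training script further down.

import numpy as np
import tensorflow as tf

def check_for_nan(name, array_like):
    # Cast to float so np.isnan also works on integer token arrays.
    values = np.asarray(array_like, dtype=np.float64)
    if np.isnan(values).any():
        raise ValueError(f"NaN values found in {name}")

# e.g. check_for_nan('en_input', en_input) for each padded array before training

# Gradient clipping I tried (same optimizer setup as in the training script below):
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, clipvalue=1.0)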

Code for the Transformer:

import tensorflow as tf
from tensorflow.keras.layers import Dropout, MultiHeadAttention, LayerNormalization, Dense, Embedding, Input
import numpy as np

def positional_encoding(max_seq_len, d_model):
    pos_enc = np.zeros((max_seq_len, d_model))

    for pos in range(max_seq_len):
        for i in range(0, d_model, 2):
            pos_enc[pos, i] = np.sin(pos / np.power(10000, (2 * i) / d_model))
            if i + 1 < d_model:
                pos_enc[pos, i + 1] = np.cos(pos / np.power(10000, (2 * i) / d_model))
    return pos_enc

def create_padding_mask(seq):
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def encoder_layer(input, d_model, num_heads, dff, mask, training, dropout_rate=0.1):
    mha_output = MultiHeadAttention(num_heads, d_model, dropout=dropout_rate)(input, input, input, attention_mask=mask)
    layernorm1 = LayerNormalization(epsilon=1e-6)(input + mha_output)

    ffn = Dense(dff, activation='relu')(layernorm1)
    ffn = Dense(d_model)(ffn)
    ffn = Dropout(dropout_rate)(ffn, training=training)

    output = LayerNormalization(epsilon=1e-6)(layernorm1 + ffn)

    return output

def encoder(input, d_model, num_heads, num_layers, dff, mask, training, max_seq_len, vocab_size, dropout_rate=0.1):
    emb_output = Embedding(vocab_size, d_model)(input)
    emb_output *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    pos_out = positional_encoding(max_seq_len, d_model)
    emb_output += pos_out[np.newaxis, :, :]
    emb_output = Dropout(dropout_rate)(emb_output, training=training)
    enc_output = emb_output
    for i in range(num_layers):
        enc_output = encoder_layer(enc_output, d_model, num_heads, dff, mask, training, dropout_rate)
    return enc_output

def decoder_layer(input, d_model, num_heads, dff, training, padding_mask, enc_output, dropout_rate=0.1):
    mha1_output, attn_weights1 = MultiHeadAttention(num_heads, d_model, dropout=dropout_rate)(input, input, input, use_causal_mask=True, return_attention_scores=True)
    layernorm1 = LayerNormalization(epsilon=1e-6)(input + mha1_output)
    mha2_output, attn_weights2 = MultiHeadAttention(num_heads, d_model, dropout=dropout_rate)(layernorm1, enc_output, enc_output, attention_mask=padding_mask, return_attention_scores=True)
    layernorm2 = LayerNormalization(epsilon=1e-6)(layernorm1 + mha2_output)

    ffn = Dense(dff, activation='relu')(layernorm2)
    ffn = Dense(d_model)(ffn)
    ffn = Dropout(dropout_rate)(ffn, training=training)

    output = LayerNormalization(epsilon=1e-6)(layernorm2 + ffn)

    return output, attn_weights1, attn_weights2

def decoder(input, d_model, num_layers, num_heads, dff, training, max_seq_len, padding_mask, vocab_size, enc_output, dropout_rate=0.1):
    attention_weights = {}
    emb_output = Embedding(vocab_size, d_model)(input)
    emb_output *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    pos_out = positional_encoding(max_seq_len, d_model)
    emb_output += pos_out[np.newaxis, :, :]
    emb_output = Dropout(dropout_rate)(emb_output, training=training)
    dec_output = emb_output

    for i in range(num_layers):
        dec_output, block1, block2 = decoder_layer(dec_output, d_model, num_heads, dff, training, padding_mask, enc_output, dropout_rate)
        attention_weights['decoder_layer{}_block1_self_att'.format(i + 1)] = block1
        attention_weights['decoder_layer{}_block2_self_att'.format(i + 1)] = block2

    return dec_output, attention_weights

def Transformer(num_layers, d_model, num_heads, dff, training, en_vocab_size, ta_vocab_size, max_seq_len, dropout_rate=0.1):
    input = Input(shape=(max_seq_len,), dtype='int32', name='inputs')
    target = Input(shape=(max_seq_len,), dtype='int32', name='targets')

    en_mask = create_padding_mask(input)
    ta_mask = create_padding_mask(target)

    encoder_output = encoder(input, d_model, num_heads, num_layers, dff, en_mask, training, max_seq_len, en_vocab_size, dropout_rate=dropout_rate)
    decoder_output, _ = decoder(target, d_model, num_layers, num_heads, dff, training, max_seq_len, ta_mask, ta_vocab_size, encoder_output, dropout_rate=dropout_rate)

    outputs = Dense(ta_vocab_size, activation='softmax')(decoder_output)

    return tf.keras.models.Model(inputs=[input, target], outputs=outputs)

class MaskedSparseCategoricalCrossentropy(tf.keras.losses.Loss):
    def __init__(self, from_logits=False, reduction=tf.keras.losses.Reduction.AUTO, name='masked_sparse_categorical_crossentropy'):
        super().__init__(reduction=reduction, name=name)
        self.sparse_categorical_crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits, reduction=tf.keras.losses.Reduction.NONE)

    def call(self, y_true, y_pred):
        mask = tf.math.not_equal(y_true, 0)
        loss = self.sparse_categorical_crossentropy(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / (tf.reduce_sum(mask) + tf.keras.backend.epsilon())

Code for Training:

import pickle
import numpy as np
from collections import Counter
from transformer import Transformer
import tensorflow as tf
from transformer import MaskedSparseCategoricalCrossentropy

with open('en-ta/English.txt', 'r') as file:
    en_sentences = file.readlines()
with open('en-ta/Tamil.txt', 'r') as file:
    ta_sentences = file.readlines()

TOTAL_SENTENCES = 200000
en_sentences = en_sentences[:TOTAL_SENTENCES]
ta_sentences = ta_sentences[:TOTAL_SENTENCES]
en_sentences = [sentence.rstrip('\n').lower() for sentence in en_sentences]
ta_sentences = [sentence.rstrip('\n') for sentence in ta_sentences]

max(len(x) for x in ta_sentences), max(len(x) for x in en_sentences)

# Assuming ta_sentences and en_sentences are lists of strings (sentences)

# Find the longest Tamil sentence
longest_ta_sentence = max(ta_sentences, key=len)

# Find the longest English sentence
longest_en_sentence = max(en_sentences, key=len)

print("Longest Tamil sentence:", longest_ta_sentence)
print("Longest English sentence:", longest_en_sentence)

# %%

PERCENTILE = 97
print(f"{PERCENTILE}th percentile length Tamil: {np.percentile([len(x) for x in ta_sentences], PERCENTILE)}")  # roughly 250
print(f"{PERCENTILE}th percentile length English: {np.percentile([len(x) for x in en_sentences], PERCENTILE)}")  # roughly 250

# %%

# Special tokens (any distinct placeholder strings work here)
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
UNKNOWN_TOKEN = '<UNKNOWN>'

# %%

ta_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ',
            'ஃ', 'அ', 'ஆ', 'இ', 'ஈ', 'உ', 'ஊ', 'எ', 'ஏ', 'ஐ', 'ஒ', 'ஓ', 'ஔ', 'க்', 'க', 'கா', 'கி', 'கீ', 'கு', 'கூ', 'கெ',
            'கே', 'கை', 'கொ', 'கோ', 'கௌ', 'ங்', 'ங', 'ஙா', 'ஙி', 'ஙீ', 'ஙு', 'ஙூ', 'ஙெ', 'ஙே', 'ஙை', 'ஙொ', 'ஙோ', 'ஙௌ', 'ச்',
            'ச', 'சா', 'சி', 'சீ', 'சு', 'சூ', 'செ', 'சே', 'சை', 'சொ', 'சோ', 'சௌ',
            'ஞ்', 'ஞ', 'ஞா', 'ஞி', 'ஞீ', 'ஞு', 'ஞூ', 'ஞெ', 'ஞே', 'ஞை', 'ஞொ', 'ஞோ', 'ஞௌ',
            'ட்', 'ட', 'டா', 'டி', 'டீ', 'டு', 'டூ', 'டெ', 'டே', 'டை', 'டொ', 'டோ', 'டௌ',
            'ண்', 'ண', 'ணா', 'ணி', 'ணீ', 'ணு', 'ணூ', 'ணெ', 'ணே', 'ணை', 'ணொ', 'ணோ', 'ணௌ',
            'த்', 'த', 'தா', 'தி', 'தீ', 'து', 'தூ', 'தெ', 'தே', 'தை', 'தொ', 'தோ', 'தௌ',
            'ந்', 'ந', 'நா', 'நி', 'நீ', 'நு', 'நூ', 'நெ', 'நே', 'நை', 'நொ', 'நோ', 'நௌ',
            'ப்', 'ப', 'பா', 'பி', 'பீ', 'பு', 'பூ', 'பெ', 'பே', 'பை', 'பொ', 'போ', 'பௌ',
            'ம்', 'ம', 'மா', 'மி', 'மீ', 'மு', 'மூ', 'மெ', 'மே', 'மை', 'மொ', 'மோ', 'மௌ',
            'ய்', 'ய', 'யா', 'யி', 'யீ', 'யு', 'யூ', 'யெ', 'யே', 'யை', 'யொ', 'யோ', 'யௌ',
            'ர்', 'ர', 'ரா', 'ரி', 'ரீ', 'ரு', 'ரூ', 'ரெ', 'ரே', 'ரை', 'ரொ', 'ரோ', 'ரௌ',
            'ல்', 'ல', 'லா', 'லி', 'லீ', 'லு', 'லூ', 'லெ', 'லே', 'லை', 'லொ', 'லோ', 'லௌ',
            'வ்', 'வ', 'வா', 'வி', 'வீ', 'வு', 'வூ', 'வெ', 'வே', 'வை', 'வொ', 'வோ', 'வௌ',
            'ழ்', 'ழ', 'ழா', 'ழி', 'ழீ', 'ழு', 'ழூ', 'ழெ', 'ழே', 'ழை', 'ழொ', 'ழோ', 'ழௌ',
            'ள்', 'ள', 'ளா', 'ளி', 'ளீ', 'ளு', 'ளூ', 'ளெ', 'ளே', 'ளை', 'ளொ', 'ளோ', 'ளௌ',
            'ற்', 'ற', 'றா', 'றி', 'றீ', 'று', 'றூ', 'றெ', 'றே', 'றை', 'றொ', 'றோ', 'றௌ',
            'ன்', 'ன', 'னா', 'னி', 'னீ', 'னு', 'னூ', 'னெ', 'னேனை',
            'ஶ்', 'ஶ', 'ஶா', 'ஶி', 'ஶீ', 'ஶு', 'ஶூ', 'ஶெ', 'ஶே', 'ஶை', 'ஶொ', 'ஶோ', 'ஶௌ',
            'ஜ்', 'ஜ', 'ஜா', 'ஜி', 'ஜீ', 'ஜு', 'ஜூ', 'ஜெ', 'ஜே', 'ஜை', 'ஜொ', 'ஜோ', 'ஜௌ',
            'ஷ்', 'ஷ', 'ஷா', 'ஷி', 'ஷீ', 'ஷு', 'ஷூ', 'ஷெ', 'ஷே', 'ஷை', 'ஷொ', 'ஷோ', 'ஷௌ',
            'ஸ்', 'ஸ', 'ஸா', 'ஸி', 'ஸீ', 'ஸு', 'ஸூ', 'ஸெ', 'ஸே', 'ஸை', 'ஸொ', 'ஸோ', 'ஸௌ',
            'ஹ்', 'ஹ', 'ஹா', 'ஹி', 'ஹீ', 'ஹு', 'ஹூ', 'ஹெ', 'ஹே', 'ஹை', 'ஹொ', 'ஹோ', 'ஹௌ',
            'க்ஷ்', 'க்ஷ', 'க்ஷா', 'க்ஷ', 'க்ஷீ', 'க்ஷு', 'க்ஷூ', 'க்ஷெ', 'க்ஷே', 'க்ஷை', 'க்ஷொ', 'க்ஷோ', 'க்ஷௌ',
            '்', 'ா', 'ி', 'ீ', 'ு', 'ூ', 'ெ', 'ே', 'ை', 'ொ', 'ோ', 'ௌ', END_TOKEN]

# %%

en_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            ':', '<', '=', '>', '?', '@',
            '[', '\\', ']', '^', '_', '`',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
            'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
            'y', 'z', '{', '|', '}', '~', END_TOKEN]

def is_valid_token(sentence, vocab):
    return all(token in vocab for token in sentence)

def find_invalid_tokens(sentence, vocab):
    return [token for token in set(sentence) if token not in vocab]

def is_valid_length(sentence, max_sequence_length):
    return len(sentence) <= max_sequence_length

ta_vocab = {v:k for k,v in enumerate(ta_vocab)}
en_vocab = {v:k for k,v in enumerate(en_vocab)}

invalid_tokens_list = []
valid_sentence_indices = []
invalid_sentence_indices = []

for index, (ta_sentence, en_sentence) in enumerate(zip(ta_sentences, en_sentences)):
    invalid_ta_tokens = find_invalid_tokens(ta_sentence, ta_vocab)
    invalid_en_tokens = find_invalid_tokens(en_sentence, en_vocab)

    if is_valid_length(ta_sentence, 250) and is_valid_length(en_sentence, 250):
        if is_valid_token(ta_sentence, ta_vocab) and is_valid_token(en_sentence, en_vocab):
            valid_sentence_indices.append(index)
        else:
            invalid_tokens_list.append((invalid_ta_tokens, invalid_en_tokens))
            invalid_sentence_indices.append(index)

ta_sentences = [ta_sentences[i] for i in valid_sentence_indices]
en_sentences = [en_sentences[i] for i in valid_sentence_indices]

def text_to_indices(sequences, vocab):
    sequences_to_ids = []
    for sequence in sequences:
        seq = [vocab[char] for char in sequence]
        sequences_to_ids.append(seq)
    return sequences_to_ids

def create_decoder_sequences(sequences, vocab, max_len=250):
    sos_token = vocab[START_TOKEN]
    eos_token = vocab[END_TOKEN]
    pad_token = vocab[PADDING_TOKEN]

    decoder_input_seqs = []
    decoder_output_seqs = []

    for seq in sequences:
        input_seq = [sos_token] + seq
        output_seq = seq + [eos_token]

        decoder_input_seqs.append(input_seq)
        decoder_output_seqs.append(output_seq)

    decoder_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(decoder_input_seqs, maxlen=max_len, padding='post', truncating='post', value=pad_token)
    decoder_output_seqs = tf.keras.preprocessing.sequence.pad_sequences(decoder_output_seqs, maxlen=max_len, padding='post', truncating='post', value=pad_token)

    return decoder_input_seqs, decoder_output_seqs

en_input = tf.keras.preprocessing.sequence.pad_sequences(text_to_indices(en_sentences, en_vocab), maxlen=250, padding='post', truncating='post', value=en_vocab[PADDING_TOKEN])
decoder_input, decoder_output = create_decoder_sequences(text_to_indices(ta_sentences, ta_vocab), ta_vocab)

en_vocab_size = 71
ta_vocab_size = 367
num_layers=1
num_heads=8
d_model=512
dff=2048
dropout_rate=0.1
training=True
max_seq_len=250
batch_size = 30

model = Transformer(num_layers, d_model, num_heads, dff, training, en_vocab_size, ta_vocab_size, max_seq_len, dropout_rate=dropout_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, clipvalue=1.0)

model.compile(optimizer=optimizer, loss=MaskedSparseCategoricalCrossentropy(), metrics=['acc'])

history = model.fit([en_input, decoder_input], decoder_output, epochs=1, batch_size=batch_size)