TensorFlow memory leak in loop

I have a transformer model. When I decode messages (translate an input sentence into a target sentence), memory usage blows up. Memory is fine when I call transformer.fit(), but in a loop like the one below it keeps growing. tf.keras.backend.clear_session() doesn't help, and accuracy decreases when I use it. gc.collect() doesn't work either.
Here is my code:

```
import gc

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# transformer, input_vectorization, output_vectorization, output_index_lookup,
# START_TOKEN, END_TOKEN, max_decoded_sentence_length and test_pairs are
# defined earlier (not shown here).


def decode_sequence(input_sentence):
    tokenized_input_sentence = input_vectorization([input_sentence])
    decoded_sentence = START_TOKEN
    for i in tf.range(max_decoded_sentence_length):
        # Re-tokenize everything decoded so far on each step.
        tokenized_target_sentence = output_vectorization([decoded_sentence])  # [:, :-1]

        # Eager call to the model, once per decoded token.
        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])

        # Greedy decoding: take the most likely token at position i.
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = output_index_lookup[sampled_token_index]
        decoded_sentence += sampled_token

        if sampled_token == END_TOKEN:
            break

    gc.collect()
    return decoded_sentence


def overall_accuracy(pairs):
    corrects = 0
    inputs = pairs[2739:]
    progress = tqdm(inputs)  # renamed from `iter` to avoid shadowing the builtin
    for i, pair in enumerate(progress):
        input_text = pair[0]
        target = pair[1]
        predicted = decode_sequence(input_text)
        # guess = '✓' if predicted == target else '✗'
        # print('Sample Number : ', i, 'Predicted : ', predicted, 'Real : ', target, guess)
        if predicted == target:
            corrects += 1
        progress.set_postfix(corrects=corrects, accuracy=corrects / (i + 1))

    return corrects / len(inputs)


print("Overall Accuracy : ", overall_accuracy(test_pairs))
```

Hi @ALER_EM,

Could you please try the suggestions below?

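Using tf.function: Wrap the model call inside your decode_sequence function in tf.function so that it runs as a TensorFlow graph. The function as written cannot be converted whole, because it mixes in Python string concatenation, np.argmax, and a dictionary lookup, so keep those steps outside the traced part. This can optimize the call and might help manage memory better.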
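As a rough sketch of the tf.function suggestion, assuming both vectorization layers pad to a fixed length: only the model call is traced, with a fixed input_signature so it is traced once, while the token bookkeeping stays in plain Python. The names toy_transformer, predict_step, greedy_tokens, SEQ_LEN, VOCAB and TRACE_LOG are hypothetical stand-ins, not part of your code.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the real transformer: two fixed-length integer
# inputs, per-position logits over a small vocabulary.
SEQ_LEN, VOCAB = 8, 16
enc_in = tf.keras.Input(shape=(SEQ_LEN,), dtype="int64")
dec_in = tf.keras.Input(shape=(SEQ_LEN,), dtype="int64")
embed = tf.keras.layers.Embedding(VOCAB, 4)
mixed = tf.keras.layers.Add()([embed(enc_in), embed(dec_in)])
logits_out = tf.keras.layers.Dense(VOCAB)(mixed)
toy_transformer = tf.keras.Model([enc_in, dec_in], logits_out)

TRACE_LOG = []  # Python side effects run only while tracing, so this counts traces


@tf.function(input_signature=[
    tf.TensorSpec([None, SEQ_LEN], tf.int64),
    tf.TensorSpec([None, SEQ_LEN], tf.int64),
])
def predict_step(encoder_tokens, decoder_tokens):
    TRACE_LOG.append(1)
    return toy_transformer([encoder_tokens, decoder_tokens], training=False)


def greedy_tokens(encoder_tokens, steps):
    """Greedy decoding over token ids; string handling is left out."""
    decoder_tokens = np.zeros((1, SEQ_LEN), dtype=np.int64)
    picked = []
    for i in range(steps):  # plain Python loop around the traced call
        logits = predict_step(encoder_tokens, tf.constant(decoder_tokens))
        token = int(np.argmax(logits[0, i, :].numpy()))
        picked.append(token)
        if i + 1 < SEQ_LEN:
            decoder_tokens[0, i + 1] = token  # feed the choice to the next step
    return picked


encoder_batch = tf.constant(np.ones((1, SEQ_LEN), dtype=np.int64))
print(greedy_tokens(encoder_batch, steps=5), "traces:", len(TRACE_LOG))
```

If the trace count stays at 1 across many sentences, the graph is being reused. Whether this removes the memory growth in your setup is something you would need to measure.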

Accuracy drop: The accuracy drop when using clear_session() might be due to the resetting of states or variables in your model. Make sure that the model state and weights are preserved correctly across calls to clear_session().
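One way to check this, shown as a sketch on a hypothetical toy_model rather than your transformer: take a copy of the weights before clear_session(), compare afterwards, and restore them only if they differ. The helpers snapshot_weights and weights_match are made-up names for illustration.

```python
import numpy as np
import tensorflow as tf


def snapshot_weights(model):
    """Copy the current weights as a list of NumPy arrays."""
    return [w.copy() for w in model.get_weights()]


def weights_match(model, snapshot):
    """True if the model's weights are identical to the snapshot."""
    current = model.get_weights()
    return len(current) == len(snapshot) and all(
        np.array_equal(c, s) for c, s in zip(current, snapshot)
    )


# Hypothetical tiny model standing in for the transformer.
toy_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])

saved = snapshot_weights(toy_model)
tf.keras.backend.clear_session()

if not weights_match(toy_model, saved):
    # Restore only if something actually changed.
    toy_model.set_weights(saved)
print("weights intact:", weights_match(toy_model, saved))
```

If the weights turn out to be unchanged, the accuracy drop probably has another cause, for example lookup tables or vectorization state being reset.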

gc.collect() only triggers Python's garbage collector, so it might not always help with memory that TensorFlow manages itself.
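To illustrate the gc.collect() point with plain Python (no TensorFlow involved): the collector frees objects that are unreachable, such as a reference cycle, but it cannot free anything that is still referenced from a cache or list. Node, make_cycle and demo are hypothetical names used only for this sketch.

```python
import gc
import weakref


class Node:
    """Plain Python object used to build a reference cycle."""

    def __init__(self):
        self.partner = None


def make_cycle():
    a, b = Node(), Node()
    a.partner, b.partner = b, a  # a and b keep each other alive
    return weakref.ref(a)        # a weak reference does not keep `a` alive


def demo():
    gc.disable()  # no automatic collections, so the result is deterministic
    try:
        cycle_ref = make_cycle()
        alive_before = cycle_ref() is not None  # cycle not yet collected
        gc.collect()
        alive_after = cycle_ref() is not None   # unreachable cycle is freed

        held = Node()
        cache = [held]                # something still references the object
        held_ref = weakref.ref(held)
        del held
        gc.collect()
        still_cached = held_ref() is not None   # reachable via `cache`, so kept
    finally:
        gc.enable()
    return alive_before, alive_after, still_cached


print(demo())
```

On CPython this should print (True, False, True). By the same logic, memory that is still referenced somewhere, inside or outside Python, will not be released by gc.collect().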

Please let us know if it helps.

Thanks