Help - Memory accumulation inside keras training loop

I am uncertain how this is getting bigger over time, can anyone with a good eye for keras point it out?

os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
from tensorflow.keras.layers import Input, Flatten, Dense, Reshape
from tensorflow.keras import Model
import tensorflow as tf

def create_critic(n, m, d, population_dummy, scores_dummy, output_dummy):
    # Create an additional input for the scores and actor's output
    score_input = tf.keras.layers.Input(shape=(n,))
    actor_output_input = tf.keras.layers.Input(shape=output_dummy.shape[1:])

    # Define the critic model
    critic_input = tf.keras.layers.Input(shape=(n, m, d))

    x1 = Flatten()(critic_input)
    x2 = Dense(64, activation='relu')(score_input)  # Dense layer for the scores
    x3 = Flatten()(actor_output_input)  # Flatten the actor's output
    x = tf.keras.layers.Concatenate()([x1, x2, x3])  # Concatenate the flattened critic input, score input, and actor's output
    x = Dense(64, activation='relu')(x)
    x = Dense(64, activation='relu')(x)
    x = Dense(1, activation='linear')(x)  # Output layer with single linear unit

    critic_model = tf.keras.models.Model([critic_input, score_input, actor_output_input], x)
    critic_model.compile(optimizer='adam', loss='mean_squared_error')  # Use MSE loss for value prediction

    return critic_model

def create_actor(n, m, d, total_parents, population_dummy, scores_dummy):
    # Create an additional input for the scores
    score_input = tf.keras.layers.Input(shape=(n,))

    # Define the actor model
    actor_input = tf.keras.layers.Input(shape=(n, m, d))

    x1 = Flatten()(actor_input)
    x2 = Dense(64, activation='relu')(score_input)  # Dense layer for the scores
    x = tf.keras.layers.Concatenate()([x1, x2])  # Concatenate the flattened actor input and the score input
    x = Dense(64, activation='relu')(x)
    x = Dense(64, activation='relu')(x)
    x = Dense(n, activation='relu')(x)  # Output layer with n linear units

    actor_model = tf.keras.models.Model([actor_input, score_input], x)
    actor_model.compile(optimizer='adam', loss='mean_squared_error')  # Use MSE loss for value prediction

    return actor_model

import numpy as np

#dummy functions to  generate fake data to develop the training pipeline
def pop_gen(b, n, m, d):
    return np.random.randint(2, size=(b, n, m, d))
def reward_gen():
    return np.random.rand
def scores_gen(n):
    return np.random.rand(1, n)

n = int(nInd[0])
m = int(nMarkers)
d = 2
total_parents = n*2
population_dummy = pop_gen(1, n, m, d)  # Extra dimension for batch size
scores_dummy =  scores_gen(n) # Extra dimension for batch size
#init actor and critic models
actor_model = create_actor(n,m,d,total_parents, population_dummy, scores_dummy)
# Feed the dummy data through the network to get an example for creating critic model
actor_output = actor_model([population_dummy, scores_dummy])
critic_model = create_critic(n,m,d, population_dummy, scores_dummy,actor_output)


for _ in range(1000):
    actor_output = actor_model([population_dummy, scores_dummy])
    policy = actor_output # find the policy
    parent_values, parent_indices = select_parents(policy) #select parents
    selected_parents = population_dummy[0][parent_indices.numpy()] #grab parents from our current population
    past_fitness = scores_gen(n)
    new_fitness = scores_gen(n)
    reward = new_fitness-past_fitness

    # update the actor
    with tf.GradientTape() as tape:
        new_policy = actor_model([population_dummy, scores_dummy], training=True)  # compute new policy with actor
        actor_loss = -tf.reduce_mean(critic_model([population_dummy, scores_dummy, new_policy]))  # compute actor loss
    # Get the gradients
    actor_grad = tape.gradient(actor_loss, actor_model.trainable_variables)
    # Update the weights
    actor_model.optimizer.apply_gradients(zip(actor_grad, actor_model.trainable_variables))

    # update the critic
    with tf.GradientTape() as tape:
        critic_value = critic_model([population_dummy, scores_dummy, actor_output], training=True)  # compute critic value
        critic_loss = tf.keras.losses.MSE(reward, critic_value)  # compute critic loss
    # Get the gradients
    critic_grad = tape.gradient(critic_loss, critic_model.trainable_variables)
    # Update the weights
    critic_model.optimizer.apply_gradients(zip(critic_grad, critic_model.trainable_variables))

    population_dummy = pop_gen(1, n, m, d)
    scores_dummy = scores_gen(n)


The memory accumulation issue in your Keras training loop could be caused by various factors including inefficient data pipeline handling, model complexity, large batch sizes, use of global variables, tensor leakage, integration issues with the JAX backend, or inefficient callbacks and custom logs. To address this, you should profile memory usage, simplify your data pipeline, adjust batch sizes, isolate components of your training script, ensure stable TensorFlow/JAX integration, use the latest library versions, review any custom code, and consider clearing the TensorFlow session after training iterations. Identifying the exact cause might require a more detailed examination of the data handling process and training loop specifics.