Running out of GPU memory in custom training loop

Hi,

The following custom training loop raises the error: ‘Allocator (GPU_0_bfc) ran out of memory’ after a couple of epochs.

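import os
from datetime import datetime

import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers

# return_scaled_sigmoid (custom output activation), training_dataset and
# val_dataset are defined elsewhere in the script (not shown).
gwd_model = tf.keras.Sequential([layers.Dense(384, activation="relu"),
                                 layers.Dense(16, activation="relu"),
                                 layers.Dense(16, activation="relu"),
                                 layers.Dense(1, activation=return_scaled_sigmoid)])
optimizer = tf.keras.optimizers.AdamW()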

@tf.function
def apply_gradients(gradients):
    optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))

NEPOCHS = 40
NEPOCHS_MAX = 200
#os.remove('/tf/stop')

loss_function = tf.keras.losses.Huber()

ntraining_batches = len(list(training_dataset))
nval_batches = len(list(val_dataset))
print(f"Number of training batches: {ntraining_batches}, number of validation batches: {nval_batches}")

tf.keras.backend.set_value(optimizer.learning_rate, 1.0e-3)

best_val_loss = float('inf')
patience = 2  # Number of epochs to wait before reducing LR
wait = 0      # Counter for epochs waited
factor = 0.96  # Factor by which to reduce LR

for epoch in range(NEPOCHS_MAX):    
    ntraining_batch = int(epoch * ntraining_batches / NEPOCHS)
    ntraining_batch = max(ntraining_batch, 1) 
    ntraining_batch = min(ntraining_batch, ntraining_batches) 
    #ntraining_batch = ntraining_batches
    nval_batch = int(epoch * nval_batches / NEPOCHS)
    nval_batch = max(nval_batch, 1) 
    nval_batch = min(nval_batch, nval_batches) 
    #nval_batch = nval_batches
                          
    print(f"Epoch: {epoch}, Training batches: {ntraining_batch}, Validation batches: {nval_batch}")
    
    loss_mean = tf.keras.metrics.Mean()
    
    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    print("current_time:", current_time_string)

    dataset = training_dataset.take(ntraining_batch)

    for dataset_features, dataset_labels in dataset:
        #print("ibatch", ibatch, " nbatches=", ntraining_batches)
        #loss = train_step(gwd_model, dataset_features, dataset_labels);
        with tf.GradientTape() as tape:
            predictions = gwd_model(dataset_features, training=True)
            loss = loss_function(dataset_labels, predictions)
        gradients = tape.gradient(loss, gwd_model.trainable_variables)
        apply_gradients(gradients)
        #optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
        loss_mean.update_state(loss)
            
    loss = loss_mean.result()

    val_loss_mean = tf.keras.metrics.Mean()

    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    print("current_time:", current_time_string)
    
    dataset = val_dataset.take(nval_batch)
    for dataset_features, dataset_labels in dataset:
        predictions = gwd_model(dataset_features, training=False)
        val_loss = loss_function(dataset_labels, predictions)
        val_loss_mean.update_state(val_loss)

    val_loss = val_loss_mean.result()

    learning_rate = optimizer.learning_rate
    
    print(f"Epoch {epoch}: Loss: {loss.numpy():.4e}, Validation Loss: {val_loss.numpy():.4e}, Learning Rate: {learning_rate.numpy():.4e}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            learning_rate = learning_rate * factor
            tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
            print(f"New learning rate: {learning_rate:.4e}.")
            wait = 0

    logs = {'loss': loss, 'val_loss': val_loss}
    
    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    layer_units = [str(layer.units) for layer in gwd_model.layers if hasattr(layer, 'units')]
    layer_units_string = 'x'.join(layer_units)
    postfix = f"{layer_units_string}-epoch-{epoch+1}-val_loss-{val_loss:.8f}-{current_time_string}.csv"
    for ilayer, layer in enumerate(gwd_model.layers):
        # weights[0] is the kernel and weights[1] the bias of each Dense layer
        pd.DataFrame(layer.weights[0].numpy()).to_csv(f"/tmp5/gwies/tf/weights{ilayer}-{postfix}", header=False, index=False)
        pd.DataFrame(layer.weights[1].numpy()).to_csv(f"/tmp5/gwies/tf/bias{ilayer}-{postfix}", header=False, index=False)

    if (epoch > NEPOCHS) and (val_loss < 1.0e-5):
        break
        
    if learning_rate < 1.0e-6:
        break
        
    stop = '/tf/stop'
    if os.path.exists(stop):
        print(f"\nStopping training as '{stop}' exists.")
        break
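The batch ramp-up at the top of the epoch loop can be pulled out into a small helper so the schedule can be checked on its own. This is a sketch: `batches_for_epoch` is a hypothetical name, and the totals 4530 and 560 are assumptions chosen only so that epoch 4 reproduces the "Training batches: 453, Validation batches: 56" line from the log.

```python
def batches_for_epoch(epoch: int, total_batches: int, n_epochs: int = 40) -> int:
    """Batches used in `epoch`; n_epochs plays the role of NEPOCHS in the loop.

    The count grows linearly with the epoch index and is clamped to
    [1, total_batches], exactly like ntraining_batch / nval_batch above.
    """
    n = int(epoch * total_batches / n_epochs)
    return min(max(n, 1), total_batches)


# Hypothetical dataset sizes (see the lead-in).
total_train, total_val = 4530, 560
for epoch in (0, 1, 4, 14, 40, 60):
    print(epoch,
          batches_for_epoch(epoch, total_train),
          batches_for_epoch(epoch, total_val))
```

With these assumed totals, the loop only reaches the full dataset at epoch 40 (NEPOCHS), while `fit()` iterates over every batch from the first epoch.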

The batch size is 65536, and the loop fails while trying to allocate a float tensor of shape [65536, 384], i.e. [batch size, number of units in the first Dense layer], in ReluGrad. Shouldn't such a tensor be allocated once before training, rather than after epoch 14? The allocator also prints the heap, and there are about 20,000 objects of size 262144 bytes (which looks like 65536 floats). The number of batches is slowly increased from epoch to epoch, but if I train the model with gwd_model.fit(), all batches are used from epoch 1 and I do not get the error even after 100 epochs.
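The sizes quoted in the error and the heap dump can be checked with a little arithmetic. This assumes float32 (4 bytes per value), which is what "type float" means in the TensorFlow message, and, as the post suggests, that each 262144-byte object holds 65536 float32 values.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2

batch_size = 65536
units_first_layer = 384

# Activation / ReluGrad buffer of the first Dense layer: [65536, 384] float32.
relu_grad_bytes = batch_size * units_first_layer * BYTES_PER_FLOAT32
# One float32 vector with one value per example in the batch.
vector_bytes = batch_size * BYTES_PER_FLOAT32

print(f"[65536, 384] float32: {relu_grad_bytes} bytes = {relu_grad_bytes / MIB:.0f} MiB")
print(f"[65536] float32: {vector_bytes} bytes")
print(f"20,000 such objects: {20_000 * vector_bytes / MIB:.0f} MiB")
```

One [65536, 384] buffer is 96 MiB, and 20,000 objects of 262144 bytes add up to about 5,000 MiB, the same order of magnitude as the 5852 MiB that nvidia-smi reports after epoch 15 in the measurements further down.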

Any suggestions as to what could be causing the allocator to run out of memory on the GPU?

Regards,
GW

I noticed that before training starts, nvtop showed the CPU memory usage of the python3 process increasing from 40 GB (the training and validation datasets reside completely in memory) to 120 GB. This was caused by the lines:

ntraining_batches = len(list(training_dataset))
nval_batches = len(list(val_dataset))

Changing those lines to:

ntraining_batches = len(training_dataset)
nval_batches = len(val_dataset)

solved that. However, the training still raises the error: ‘Allocator (GPU_0_bfc) ran out of memory’ after 14 epochs.
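The difference between the two versions is that `len(list(dataset))` iterates over the dataset and keeps every batch in a Python list, while `len(dataset)` returns the cardinality that tf.data already knows. A minimal sketch, assuming the datasets are built with a statically known size (as with `from_tensor_slices(...).batch(...)`); the tensor shapes are placeholders:

```python
import tensorflow as tf

# Small in-memory dataset standing in for training_dataset (1000 examples).
features = tf.zeros([1000, 8])
labels = tf.zeros([1000, 1])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(64)

# len() reads the statically known cardinality; no batch is produced.
n_batches = len(dataset)                  # 15 full batches + 1 partial batch
cardinality = int(dataset.cardinality())  # same value, from a scalar tensor

# After filter() the number of elements can no longer be known statically,
# so cardinality() is UNKNOWN and len() raises TypeError.
filtered = dataset.filter(lambda x, y: tf.reduce_all(y >= 0))
unknown = int(filtered.cardinality())
try:
    len(filtered)
except TypeError as err:
    print("len() not available:", err)
```

If a pipeline does end up with unknown cardinality, counting batches once by iteration is still possible, but only `len(dataset)` avoids touching the data.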

Any ideas?

Regards,
GW

In order to troubleshoot the issue I added the following statements to the script:

At the beginning:

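import subprocess

physical_devices = tf.config.list_physical_devices('GPU')
# Allocate GPU memory on demand instead of reserving (almost) all of it up front
tf.config.experimental.set_memory_growth(physical_devices[0], True)

def query_compute_apps():
    """Print the GPU memory used per process, as reported by nvidia-smi."""
    print("GPU MEMORY USAGE")
    result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'], stdout=subprocess.PIPE)
    output = result.stdout.decode('utf-8')
    print(output)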

This allows you to monitor how much GPU memory is allocated. I call query_compute_apps() before and after training and before and after validation:

After training epoch 1: 708 MiB
After training epoch 4: 1220 MiB
After training epoch 7: 2244 MiB
After training epoch 11: 4292 MiB
After training epoch 15: 5852 MiB
In epoch 18:
2024-01-15 13:49:09.220088: W external/local_tsl/tsl/framework/bfc_allocator.cc:497] _***************************************************
2024-01-15 13:49:09.220248: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at numeric_op.h:82 : RESOURCE_EXHAUSTED: OOM when allocating tensor with shape[65536,384] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
Traceback (most recent call last):
  File "/tmp5/gwies/docker/gpu/neural-4-gpu.py", line 145, in <module>
    gradients = tape.gradient(loss, gwd_model.trainable_variables)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/backprop.py", line 1066, in gradient
    flat_grad = imperative_grad.imperative_grad(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/imperative_grad.py", line 67, in imperative_grad
    return pywrap_tfe.TFE_Py_TapeGradient(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/backprop.py", line 148, in _gradient_function
    return grad_fn(mock_op, *out_grads)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/ops/nn_grad.py", line 414, in _ReluGrad
    return gen_nn_ops.relu_grad(grad, op.outputs[0])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 11788, in relu_grad
    _ops.raise_from_not_ok_status(e, name)
  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/framework/ops.py", line 5883, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None # pylint: disable=protected-access
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.ResourceExhaustedError: {{function_node _wrapped__ReluGrad_device/job:localhost/replica:0/task:0/device:GPU:0}} OOM when allocating tensor with shape[65536,384] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:ReluGrad] name:
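The query_compute_apps() helper only prints nvidia-smi's output. Parsing it into numbers makes it easier to log or compare readings between steps. A sketch that assumes the CSV layout shown in this thread ("pid, used_gpu_memory [MiB]" followed by rows like "46221, 708 MiB"); `parse_compute_apps` and `query_gpu_memory_mib` are hypothetical names:

```python
from __future__ import annotations

import subprocess


def parse_compute_apps(csv_text: str) -> dict[int, int]:
    """Map pid -> used GPU memory (MiB) from nvidia-smi CSV output such as:

        pid, used_gpu_memory [MiB]
        46221, 708 MiB
    """
    usage: dict[int, int] = {}
    rows = [row.strip() for row in csv_text.splitlines() if row.strip()]
    for row in rows[1:]:  # rows[0] is the header
        pid_field, mem_field = (field.strip() for field in row.split(",", 1))
        value = mem_field.split()[0]  # "708 MiB" -> "708"
        if value.isdigit():  # skip entries such as "[N/A]"
            usage[int(pid_field)] = int(value)
    return usage


def query_gpu_memory_mib() -> dict[int, int]:
    """Run the same nvidia-smi query as query_compute_apps() and parse it."""
    result = subprocess.run(
        ["nvidia-smi", "--query-compute-apps=pid,used_memory", "--format=csv"],
        stdout=subprocess.PIPE, text=True, check=True,
    )
    return parse_compute_apps(result.stdout)
```

Inside a Docker container, the PIDs that nvidia-smi reports may not match `os.getpid()`, so matching on the PID can require some care there.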

If I call model.fit with the same inputs, the same batch size, etc., GPU usage starts at 456 MiB and stays constant.
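The per-epoch readings listed earlier can be turned into increments with a few lines of Python. The numbers are copied from the post; nothing else is assumed.

```python
# nvidia-smi readings after each training epoch, as reported in the post (MiB).
readings_mib = {1: 708, 4: 1220, 7: 2244, 11: 4292, 15: 5852}

epochs = sorted(readings_mib)
pairs = list(zip(epochs, epochs[1:]))
increments = [readings_mib[cur] - readings_mib[prev] for prev, cur in pairs]
for (prev, cur), inc in zip(pairs, increments):
    print(f"epoch {prev:>2} -> {cur:>2}: +{inc} MiB")

total_growth = readings_mib[epochs[-1]] - readings_mib[epochs[0]]
per_epoch = total_growth / (epochs[-1] - epochs[0])
print(f"total: +{total_growth} MiB, about {per_epoch:.0f} MiB per epoch")
```

The first three increments are exactly 512, 1024 and 2048 MiB. That pattern would fit an allocator that extends its reserved pool in doubling chunks while memory growth is enabled, so nvidia-smi may be showing reserved memory rather than live tensors; the thread does not confirm this. Either way, the total keeps growing, whereas with model.fit it stays at 456 MiB.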

I then added query_compute_apps() calls before and after each statement in the training step:

Epoch: 4, Training batches: 453, Validation batches: 56
GPU MEMORY USAGE
pid, used_gpu_memory [MiB]
46221, 708 MiB

Training of 453 batches will take about 102.34 seconds
2024-01-15 19:08:02.485383: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 7134440256 exceeds 10% of free system memory.
Before predictions
GPU MEMORY USAGE
pid, used_gpu_memory [MiB]
46221, 708 MiB

After predictions
GPU MEMORY USAGE
pid, used_gpu_memory [MiB]
46221, 708 MiB

After loss
GPU MEMORY USAGE
pid, used_gpu_memory [MiB]
46221, 708 MiB

After gradients
GPU MEMORY USAGE
pid, used_gpu_memory [MiB]
46221, 1220 MiB

After apply_gradients
GPU MEMORY USAGE
pid, used_gpu_memory [MiB]
46221, 1220 MiB

So the culprit is this statement:

gradients = tape.gradient(loss, gwd_model.trainable_variables)

I already added 'del gradients' after apply_gradients, but the issue remains.
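The loop above compiles only apply_gradients with tf.function and runs the forward pass and tape.gradient eagerly. The TensorFlow custom-training-loop guide instead compiles the whole step, which is also what model.fit() does internally. Whether this avoids the memory growth seen here with 2.14 is not established in this thread. The sketch uses a small stand-in model, random data, and Adam instead of AdamW; all of these are assumptions.

```python
import tensorflow as tf

# Small stand-in model; the real model, loss settings and batch size differ.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
loss_fn = tf.keras.losses.Huber()
optimizer = tf.keras.optimizers.Adam(1e-3)


@tf.function  # forward pass, gradient and update are traced into one graph
def train_step(features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


features = tf.random.normal([256, 8])
labels = tf.random.normal([256, 1])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(64)

for epoch in range(2):
    epoch_loss = tf.keras.metrics.Mean()
    for x, y in dataset:
        epoch_loss.update_state(train_step(x, y))
    final_loss = float(epoch_loss.result())
    print(f"epoch {epoch}: loss {final_loss:.4f}")
```

If the last batch of a real dataset is smaller than the others, the function is traced a second time for that shape; that is expected and happens only once.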

Any suggestions?

Regards,
GW

I’ve also run into unbounded memory usage when using a GradientTape context. In my case, I concluded that the problem was in the model.fit and optimizer.apply_gradients methods. The only workaround I have found is to downgrade to TensorFlow 2.9; that’s the newest version that doesn’t have these memory leaks.

If you need a newer version, you could try the following: do a bit of training in a subprocess, pass the model weights back to the main process, kill the subprocess, and pass the weights to a newly created one. If you’re on Linux, you’ll need to call multiprocessing.set_start_method('spawn') before any other function calls in your script. TensorFlow doesn’t like being forked, and it also seems to crash randomly, even when 'spawn' is used 🙃.

Hi,

Thank you, I can confirm that version 2.9.3 does not have this memory leak!

Regards,
GW

Hi,

I have been testing other versions as well: the latest version that does not have this memory leak is 2.13.0, and the first version that has it is 2.14.0.
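A small guard can flag runs on a version at or after the first one reported as leaking here. This is a sketch based only on this thread's observations (TensorFlow Docker images on Linux); the thread says nothing about releases after those tested, and `at_or_after_first_leaky_release` is a hypothetical name.

```python
from __future__ import annotations

import re

# From this thread: 2.13.0 did not show the leak, 2.14.0 was the first that did.
FIRST_LEAKY = (2, 14)


def major_minor(version: str) -> tuple[int, int]:
    """Parse '2.14.0' (or '2.14.0-rc1', '2.9.3', ...) into (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def at_or_after_first_leaky_release(version: str) -> bool:
    # Compare as integer tuples: as strings, "2.9" > "2.14" would be True.
    return major_minor(version) >= FIRST_LEAKY
```

Typical use would be `at_or_after_first_leaky_release(tf.__version__)` at the top of the training script, e.g. to print a warning before a long run.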

Regards,
GW

For completeness, when I refer to a TensorFlow version, I mean the version that runs within the corresponding Docker container under Linux.

Regards,
GW