Getting a memory error when training a larger dataset on the GPU

Hi,

The following model trains fine on the GPU with a 1M dataset:

history=model.fit(x_train, y_train,
                      batch_size=65536, epochs=1000,
                      callbacks=[callback],
                      validation_data=(x_test, y_test))

However, when I use a 10M dataset I am getting the error:

Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.

Are x_train, y_train, x_test and y_test not copied in batches to the GPU?

Regards,
GW

Can you try reducing the batch size? This error message is generated when there is not enough GPU memory to handle the batch size.

Thank you!

Hi,

I forgot to mention that even if I reduce the batch size to 128 I am still getting the memory error, so I have the strong impression that x_train, y_train, x_test and y_test are copied in full to the GPU… Is there a way to have them copied in batches?

Regards,
GW

I don’t think the data is copied completely to the GPU (though maybe the validation data is?). It definitely won’t be if you wrap your data in a tf.data.Dataset.

But it’s hard to say what’s wrong without more knowledge of the model you are building and the dataset.

Unrelated: Don’t use your test data as the validation data set. Split the validation data from the training data.
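For example, here is a minimal sketch using Keras’s built-in validation_split argument instead of passing the test set (the 0.1 split is an arbitrary choice, not a value from this thread):

# Sketch only: model.fit holds out the last 10% of the training arrays as
# validation data, so x_test/y_test stay untouched for final evaluation.
history = model.fit(x_train, y_train,
                    batch_size=65536, epochs=1000,
                    callbacks=[callback],
                    validation_split=0.1)

Note that validation_split only applies to array inputs; if you later switch to a tf.data pipeline, split the validation arrays off first (e.g. with a second train_test_split call) before building the datasets.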

Hi,

Thank you, let me share the complete script:

csv="inputs-10M.csv"

features=pd.read_csv(csv, dtype = 'int8', converters = {'RESULT': float})

labels=features.pop('RESULT')

x_train, x_test, y_train, y_test=train_test_split(features.to_numpy(), labels.to_numpy(), test_size=0.2)

model=tf.keras.Sequential([layers.Dense(192,activation="relu"),
                               layers.Dense(16,activation="relu"),
                               layers.Dense(16,activation="relu"),
                               layers.Dense(1,activation="sigmoid")])

model.compile(optimizer=tf.keras.optimizers.Adam(),loss = tf.keras.losses.MeanSquaredError())

callback=tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='auto', patience=10)

history=model.fit(x_train, y_train,
                      batch_size=65536, epochs=1000,
                      callbacks=[callback],
                      validation_data=(x_test, y_test))

This runs fine when I use a CSV file with 1M records, but raises the error when I use a CSV file with 10M records, even if I reduce the batch_size to 128.

Any advice?

Regards,
GW

Does this work?

csv="inputs-10M.csv"

features=pd.read_csv(csv, dtype = 'int8', converters = {'RESULT': float})

labels=features.pop('RESULT')
print(len(features.columns))

x_train, x_test, y_train, y_test=train_test_split(features.to_numpy(), labels.to_numpy(), test_size=0.2)
train = Dataset.from_tensor_slices((x_train, y_train)).shuffle(4*128).batch(128)
validate = Dataset.from_tensor_slices((x_test, y_test)).batch(128)

model=tf.keras.Sequential([layers.Dense(192,activation="relu"),
                               layers.Dense(16,activation="relu"),
                               layers.Dense(16,activation="relu"),
                               layers.Dense(1,activation="sigmoid")])

model.compile(optimizer=tf.keras.optimizers.Adam(),loss = tf.keras.losses.MeanSquaredError())

callback=tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='auto', patience=10)

history=model.fit(train,
                      epochs=1000,
                      callbacks=[callback],
                      validation_data=validate)

Does features have the same number of columns for both datasets?

Hi,

Yes, features has the same number of columns for both datasets. When I execute:

train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(4*128).batch(128)
validate = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)

I am getting the errors:

2022-12-01 12:44:29.498244: W tensorflow/tsl/framework/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.43GiB (rounded to 1536092928)requested by op _EagerConst
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
2022-12-01 12:44:29.498311: I tensorflow/tsl/framework/bfc_allocator.cc:1034] BFCAllocator dump for GPU_0_bfc
2022-12-01 12:44:29.498328: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (256): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498339: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (512): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498350: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (1024): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498359: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (2048): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498369: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (4096): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498379: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (8192): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498388: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (16384): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498398: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (32768): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498407: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (65536): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498417: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (131072): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498426: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (262144): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498435: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (524288): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498445: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (1048576): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498454: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (2097152): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498464: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (4194304): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498473: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (8388608): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498482: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (16777216): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498492: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (33554432): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498501: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (67108864): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498511: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (134217728): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498520: I tensorflow/tsl/framework/bfc_allocator.cc:1041] Bin (268435456): 	Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-12-01 12:44:29.498533: I tensorflow/tsl/framework/bfc_allocator.cc:1057] Bin for 1.43GiB was 256.00MiB, Chunk State: 
2022-12-01 12:44:29.498541: I tensorflow/tsl/framework/bfc_allocator.cc:1095]      Summary of in-use Chunks by size: 
2022-12-01 12:44:29.498549: I tensorflow/tsl/framework/bfc_allocator.cc:1102] Sum Total of in-use chunks: 0B
2022-12-01 12:44:29.498557: I tensorflow/tsl/framework/bfc_allocator.cc:1104] total_region_allocated_bytes_: 0 memory_limit_: 924254208 available bytes: 924254208 curr_region_allocation_bytes_: 924254208
2022-12-01 12:44:29.498580: I tensorflow/tsl/framework/bfc_allocator.cc:1110] Stats: 
Limit:                       924254208
InUse:                               0
MaxInUse:                            0
NumAllocs:                           0
MaxAllocSize:                        0
Reserved:                            0
PeakReserved:                        0
LargestFreeBlock:                    0

2022-12-01 12:44:29.498599: W tensorflow/tsl/framework/bfc_allocator.cc:492] <allocator contains no memory>

So it still seems that TensorFlow tries to load the complete tensor onto the GPU instead of one batch at a time.

Any suggestions?
GW

I’m having the same problem. I have a large (5.7GB) NumPy array called train_x. When I try to create a tf.data.Dataset to contain it, GPU memory runs out (I only have a 2GB GeForce MX150).

Here is my minimal code to demonstrate the problem:

import numpy as np
import tensorflow as tf
from tensorflow.data import Dataset

train_x = np.load('train_x.npy')  # 5.7GB, shape=(51182, 100, 300), dtype=float32
train_x_tensor = tf.convert_to_tensor(train_x)
train_dataset = Dataset.from_tensor_slices(train_x_tensor)

This fails with an error similar to the one in the discussion above:

2022-12-10 22:58:56.614124: W tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.72GiB (rounded to 6141840128)requested by op _EagerConst

This happens already when I create the Dataset, so there is no chance to use .batch().

I also tried mmap’ing the NumPy arrays to avoid loading them entirely into CPU memory, but it makes no difference.

I’m using tensorflow-gpu 2.10.0 installed from conda-forge.


I had the same issue (see my previous comment), but I found a solution in this thread thanks to advice from @markdaoust: create the tensors that go into the Dataset on the CPU. So instead of calling from_tensor_slices directly, as in the code above, you could try this:

with tf.device("CPU"):
    train = Dataset.from_tensor_slices((x_train, y_train)).shuffle(4*128).batch(128)
    validate = Dataset.from_tensor_slices((x_test, y_test)).batch(128)

At least it helped in my case; the large dataset is being stored on the CPU side (in system RAM) and only sent in batches to the GPU, not all at once.

Hope this helps,
Osma


Hi Osma,

Yes, that resolved the error, thank you. But how can I now verify that the GPU is actually being used, i.e. that with tf.device("CPU") does not cause training to run on the CPU instead of the GPU?

Thanks,
GW


FYI, you can check whether the GPU is being used by running nvidia-smi on the host before you start model.fit and during model.fit. You will see GPU-Util going up and a /usr/bin/python3 process being created that uses a certain amount of memory.

Regards,
GW


Right, that’s what I did as well. You can follow GPU utilization with nvidia-smi -l (the -l switch makes it loop until you stop it) and CPU utilization with, for example, top. It’s quite obvious that the GPU is being used, while CPU usage remains quite low.

Also, tf.device is used here as a Python context manager (a "with" statement), which makes it clear that it only applies to that block and nothing else.
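If you want a programmatic check as well, here is a minimal sketch (not from the posts above) using tf.debugging.set_log_device_placement, which logs the device each op runs on:

import tensorflow as tf

# call this at the start of the program; op placements are then printed to the log
tf.debugging.set_log_device_placement(True)

with tf.device("CPU"):
    # the large constant stays in host RAM
    data = tf.random.uniform((1024, 192))

# ops outside the block are placed on the GPU if one is available, so the log
# should show something like "MatMul ... device:GPU:0"
result = tf.linalg.matmul(data, tf.random.uniform((192, 16)))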


None of the solutions proposed above worked for me.

What worked for me (with TF 2.4) was changing how the data is loaded into the tf.data.Dataset. Specifically, I switched from using from_tensor_slices to using from_generator. I am tackling semantic segmentation with volumes of shape 64x64x64. Here is some pseudocode:

input_volumes_list = [...]  # list containing the input volumes that have shape 64x64x64
input_masks_list = [...]  # list containing the corresponding segmentation masks also of shape 64x64x64

# define generator function
def generator_images_and_masks():
    for idx in range(len(input_volumes_list)):
        # extract one image and the corresponding mask
        img = input_volumes_list[idx]
        mask = input_masks_list[idx]

        # convert to TF tensors
        img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)
        mask_tensor = tf.convert_to_tensor(mask, dtype=tf.float32)

        yield img_tensor, mask_tensor

# create dataset using generator function and specifying shapes and dtypes
dataset = tf.data.Dataset.from_generator(generator_images_and_masks, 
                                         output_signature=(tf.TensorSpec(shape=(64, 64, 64), dtype=tf.float32),
                                                           tf.TensorSpec(shape=(64, 64, 64), dtype=tf.float32)))


When doing this I don’t get the GPU memory error anymore.
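For completeness, a minimal sketch of how such a generator-backed dataset could then be fed to training; the batch size and epoch count are placeholders, and model is assumed to be a compiled Keras model as in the earlier posts:

# Sketch only: batch the generator output so it streams to the GPU one batch at a time
train_ds = dataset.batch(4).prefetch(tf.data.AUTOTUNE)

history = model.fit(train_ds, epochs=10)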

Thanks for this discussion; it also helped me solve my memory issues.
It seems that, at least up to TF 2.10, getting the next element of a dataset created through from_tensor_slices causes the whole dataset to be copied, which won’t fit into my RAM.

I tested this with elem = next(iter(train_dataset)), which crashes the kernel by allocating too much memory for the tensor slices, but works perfectly fine with the generator.

Interestingly, the generator is MUCH slower, but it also works with .shuffle() (over the full dataset length) and with .cache(), which makes the subsequent epochs roughly as fast as the from_tensor_slices version. I would have expected one or both of these operations to cause memory issues again, but somehow they don’t.

If someone could explain what is going on, and whether there is a way to avoid the first epoch being extremely slow (200 s vs. <20 s), I’d appreciate it.
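For reference, a rough sketch of the pipeline described above (dataset and input_volumes_list refer to the from_generator example in the earlier post; the batch size is a placeholder):

# Sketch only: cache right after the generator so the slow generation is paid once,
# then shuffle over the full dataset length and batch.
train_dataset = (dataset
                 .cache()
                 .shuffle(len(input_volumes_list))
                 .batch(8)
                 .prefetch(tf.data.AUTOTUNE))

By default .cache() keeps the generated elements in memory (it can also cache to a file if you pass it a path), which is why only the first epoch pays the full generator cost.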