Problems with training a model on a dataset that doesn't fit into RAM memory

Hi everyone!

I’m a student who’s trying to get the python code of a model from a research paper to work. I’ve managed to tweak the code such that it is suitable for my own dataset and such that it can train on a very small dataset without generating any erros, showing that the code itself is fine. My dataset, however, as a whole, is too large to fit into RAM memory, it’s a dataset of 1,2T. Even splitting everything in batches of 5 files results in batches of 10 GB, as I’m working with multicoil mri data. This batch size in and of itself, however, is fine, as I have 40GB of RAM available during busy periods and if it’s calm (like now), I can even use 100GB of RAM. This means that I would want to train by loading in these batches, which I stored as separate .npy files, on the fly.

When conducting some research, I found that the .from_generator() class/function can help with this, so I tried to write code which can do this for me. These are the important excerpts of code that I wrote for the problem at hand:

# Other stuff preceding

## Prepare dataset

path_to_save_mri_data = '/usr/local/micapollo01/MIC/DATA/STUDENTS/mvhave7/Results/Preprocessing/mri/'
path_to_save_grappa_data = '/usr/local/micapollo01/MIC/DATA/STUDENTS/mvhave7/Results/Preprocessing/grappa/'


file_paths_train = sorted(glob.glob(path_to_save_mri_data+"training_data_GrappaNet_16_coils_batch_*.npy"))
file_paths_train_GT = sorted(glob.glob(path_to_save_mri_data+"training_data_GT_GrappaNet_16_coils_batch_*.npy"))
file_paths_val = sorted(glob.glob(path_to_save_mri_data+"validation_data_GrappaNet_16_coils_batch_*.npy"))
file_paths_val_GT = sorted(glob.glob(path_to_save_mri_data+"validation_data_GT_GrappaNet_16_coils_batch_*.npy"))
file_paths_grappa_indx_train = sorted(glob.glob(path_to_save_grappa_data+"grappa_train_indx_GrappaNet_16_coils_batch_*.npy"))
file_paths_grappa_indx_val = sorted(glob.glob(path_to_save_grappa_data+"grappa_validation_indx_GrappaNet_16_coils_batch_*.npy"))
file_paths_grappa_wt = sorted(glob.glob(path_to_save_grappa_data+"grappa_wt_batch_*.pickle"))
file_paths_grappa_p = sorted(glob.glob(path_to_save_grappa_data+"grappa_p_batch_*.pickle"))


def train_generator(file_paths_train, file_paths_train_GT, file_paths_grappa_indx_train, file_paths_grappa_wt, file_paths_grappa_p):
    global grappa_wt
    global grappa_p
    for file_path_train, file_path_train_GT, file_path_grappa_indx_train, file_path_grappa_wt, file_path_grappa_p in zip (file_paths_train, file_paths_train_GT, file_paths_grappa_indx_train, file_paths_grappa_wt, file_paths_grappa_p):
        x_train = np.load(file_path_train)
        y_train = np.load(file_path_train_GT)
        grappa_train_indx = np.load(file_path_grappa_indx_train)
        with open(file_path_grappa_wt, 'rb') as handle:
            grappa_wt = pickle.load(handle)
        with open(file_path_grappa_p, 'rb') as handle:
            grappa_p = pickle.load(handle)
        
        yield ((x_train, grappa_train_indx), y_train)

        del x_train, y_train, grappa_train_indx
        time.sleep(1)
        gc.collect()
        time.sleep(1)


def validation_generator(file_paths_val, file_paths_val_GT, file_paths_grappa_indx_val, file_paths_grappa_wt, file_paths_grappa_p):
    global grappa_wt
    global grappa_p
    for file_path_val, file_path_val_GT, file_path_grappa_indx_val, file_path_grappa_wt, file_path_grappa_p in zip(file_paths_val, file_paths_val_GT, file_paths_grappa_indx_val, file_paths_grappa_wt, file_paths_grappa_p):
        x_test = np.load(file_path_val)
        y_test = np.load(file_path_val_GT)
        grappa_test_indx = np.load(file_path_grappa_indx_val)
        with open(file_path_grappa_wt, 'rb') as handle:
            grappa_wt = pickle.load(handle)
        with open(file_path_grappa_p, 'rb') as handle:
            grappa_p = pickle.load(handle)

        yield ((x_test, grappa_test_indx), y_test)

        del x_test, y_test, grappa_test_indx
        time.sleep(1)
        gc.collect()
        time.sleep(1)


print('Done. Setting up tensorflow structure to process in batches...')


## Create a .from_generator() object

training_dataset = tf.data.Dataset.from_generator(generator=lambda: train_generator(file_paths_train, file_paths_train_GT, file_paths_grappa_indx_train, file_paths_grappa_wt, file_paths_grappa_p), output_shapes=(((None, None, None, None), (None,)), (None, None, None)), output_types=((tf.float32, tf.int64), tf.float32))
validation_dataset = tf.data.Dataset.from_generator(generator=lambda: validation_generator(file_paths_val, file_paths_val_GT, file_paths_grappa_indx_val, file_paths_grappa_wt, file_paths_grappa_p), output_shapes=(((None, None, None, None), (None,)), (None, None, None)), output_types=((tf.float32, tf.int64), tf.float32))


# The model is built

## Train the model

model_name = "/usr/local/micapollo01/MIC/DATA/STUDENTS/mvhave7/Results/Models/best_model_GrappaNet.h5"

def step_decay(epoch, initial_lrate, drop, epochs_drop):
    return initial_lrate * math.pow(drop, math.floor((1+epoch)/float(epochs_drop)))

class ClearMemory(Callback):
    def on_epoch_end(self, epoch, logs=None):
        gc.collect()
        k.clear_session()

def get_callbacks(model_file, learning_rate_drop=0.7, learning_rate_patience=7, verbosity=1):
    callbacks = list()
    callbacks.append(ModelCheckpoint(model_file, save_best_only=True))
    callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience, verbose=verbosity))
    callbacks.append(EarlyStopping(verbose=verbosity, patience=200))
    callbacks.append(ClearMemory())
    return callbacks

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    input_shape = (crop_size[1],crop_size[2],crop_size[0])
    epochs = 1  # In the original paper, 20 epochs were used
    batch_size = 8  # This batch size has unit 'number of slices', not 'number of files'
    model = build_model(input_shape)
    metrics = tf.keras.metrics.RootMeanSquaredError()
    model.compile(loss=model_loss_ssim, optimizer=Adam(learning_rate=0.0003), metrics=[metrics], run_eagerly=True)
    #model.compile(loss=model_loss_ssim, optimizer=RMSprop(learning_rate=0.0003), metrics=[metrics])


history = model.fit(training_dataset,
            epochs=epochs,
            batch_size=batch_size,
            shuffle=False,
            validation_data=validation_dataset,
            callbacks=get_callbacks(model_name,0.6,10,1),
            max_queue_size=32,
            workers=100,
            use_multiprocessing=False)

However, when I run the script and training starts, I still see the RAM memory usage climbing up all the way above 100 GB, resulting in an std::bad_alloc() error. I’ve been researching and debugging for the past number of weeks, and I’m beginning to lose hope in getting this code to work :frowning:

Are there any people who spot problems in what I’m doing? Or who have other suggestions?

I’m using tensorflow 2.2.0 and switching to another deep learning python framework would be very tedious as the model I’m trying to test (the model from the research paper) is created using tensorflow.

Many, many thanks in advance! :pray:

Hi @normalguy19, Once you have loaded the dataset and converted it to a batch dataset, could you please try by pre fetching the dataset. For example, training_dataset = ds_train.prefetch(buffer_size=tf.data.AUTOTUNE).

I also recommend you to use the latest stable version (2.15) of tensorflow. tf.data.Dataset.from_generator was deprecated i recommend you to use tf.keras.utils for loading the dataset. Thank You.

1 Like

Hi @Kiran_Sai_Ramineni, thanks a lot for the very swift reply! I will look into prefetching the dataset.

When it comes to upgrading the tensorflow version: Unfortunately, upgrading my tensorflow version gives dependency issues with other packages used in the code of the research paper, so I just stuck with tensorflow 2.2.0 which is also the version used in the docker file of the original code.

Hi, I just tried implementing your suggestion and unfortunately it didn’t seem to help. When running the script, I still see the RAM surging to 100+GB followed by a termination. I’ve added the output of the script below, as I now seem to get a different termination reason (Segmentation failed). Additionally, I’ve added my modified code. Note that I couldn’t use tf.data.AUTOTUNE as this is only available in tensorflow 2.3.0+, so I just used a manually chosen batch size which should definitely fit into memory.

It would mean a lot if you would like to take another look at this. Many hopeful thanks!

Script output:

Use CPU socket 1
Resource limit set. Importing libraries...
Libraries imported. Starting to prepare the dataset...
Done. Setting up tensorflow structure to process in batches...
2023-11-29 23:34:41.793058: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
2023-11-29 23:34:45.642500: E tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW
2023-11-29 23:34:45.642590: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: micsd01
2023-11-29 23:34:45.642602: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: micsd01
2023-11-29 23:34:45.642769: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 525.147.5
2023-11-29 23:34:45.642812: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 525.125.6
2023-11-29 23:34:45.642821: E tensorflow/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.125.6 does not match DSO version 525.147.5 -- cannot find working devices in this configuration
2023-11-29 23:34:45.643354: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2023-11-29 23:34:45.652823: I tensorflow/core/platform/profile_utils/cpu_utils.cc:102] CPU Frequency: 2600090000 Hz
2023-11-29 23:34:45.654193: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f540c000b60 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2023-11-29 23:34:45.654213: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Done. Building the GrappaNet model architecture...
Done. Training the model...
Segmentation fault (core dumped)

Modified code:

# Imports and other stuff that's not important
# ...

batch_size = 8
# This batch size has unit 'number of slices', not 'number of files'


print('Libraries imported. Starting to prepare the dataset...')


## Prepare dataset

path_to_save_mri_data = '/usr/local/micapollo01/MIC/DATA/STUDENTS/mvhave7/Results/Preprocessing/mri/'
path_to_save_grappa_data = '/usr/local/micapollo01/MIC/DATA/STUDENTS/mvhave7/Results/Preprocessing/grappa/'


file_paths_train = sorted(glob.glob(path_to_save_mri_data+"training_data_GrappaNet_16_coils_batch_*.npy"))
file_paths_train_GT = sorted(glob.glob(path_to_save_mri_data+"training_data_GT_GrappaNet_16_coils_batch_*.npy"))
file_paths_val = sorted(glob.glob(path_to_save_mri_data+"validation_data_GrappaNet_16_coils_batch_*.npy"))
file_paths_val_GT = sorted(glob.glob(path_to_save_mri_data+"validation_data_GT_GrappaNet_16_coils_batch_*.npy"))
file_paths_grappa_indx_train = sorted(glob.glob(path_to_save_grappa_data+"grappa_train_indx_GrappaNet_16_coils_batch_*.npy"))
file_paths_grappa_indx_val = sorted(glob.glob(path_to_save_grappa_data+"grappa_validation_indx_GrappaNet_16_coils_batch_*.npy"))
file_paths_grappa_wt = sorted(glob.glob(path_to_save_grappa_data+"grappa_wt_batch_*.pickle"))
file_paths_grappa_p = sorted(glob.glob(path_to_save_grappa_data+"grappa_p_batch_*.pickle"))


def train_generator(file_paths_train, file_paths_train_GT, file_paths_grappa_indx_train, file_paths_grappa_wt, file_paths_grappa_p):
    global grappa_wt
    global grappa_p
    for file_path_train, file_path_train_GT, file_path_grappa_indx_train, file_path_grappa_wt, file_path_grappa_p in zip (file_paths_train, file_paths_train_GT, file_paths_grappa_indx_train, file_paths_grappa_wt, file_paths_grappa_p):
        x_train = np.load(file_path_train)
        y_train = np.load(file_path_train_GT)
        grappa_train_indx = np.load(file_path_grappa_indx_train)
        with open(file_path_grappa_wt, 'rb') as handle:
            grappa_wt = pickle.load(handle)
        with open(file_path_grappa_p, 'rb') as handle:
            grappa_p = pickle.load(handle)
        
        yield ((x_train, grappa_train_indx), y_train)

def validation_generator(file_paths_val, file_paths_val_GT, file_paths_grappa_indx_val, file_paths_grappa_wt, file_paths_grappa_p):
    global grappa_wt
    global grappa_p
    for file_path_val, file_path_val_GT, file_path_grappa_indx_val, file_path_grappa_wt, file_path_grappa_p in zip(file_paths_val, file_paths_val_GT, file_paths_grappa_indx_val, file_paths_grappa_wt, file_paths_grappa_p):
        x_test = np.load(file_path_val)
        y_test = np.load(file_path_val_GT)
        grappa_test_indx = np.load(file_path_grappa_indx_val)
        with open(file_path_grappa_wt, 'rb') as handle:
            grappa_wt = pickle.load(handle)
        with open(file_path_grappa_p, 'rb') as handle:
            grappa_p = pickle.load(handle)

        yield ((x_test, grappa_test_indx), y_test)


print('Done. Setting up tensorflow structure to process in batches...')


## Create a .from_generator() object

training_dataset = tf.data.Dataset.from_generator(generator=lambda: train_generator(file_paths_train, file_paths_train_GT, file_paths_grappa_indx_train, file_paths_grappa_wt, file_paths_grappa_p), output_shapes=(((None, None, None, None), (None,)), (None, None, None)), output_types=((tf.float32, tf.int64), tf.float32))
validation_dataset = tf.data.Dataset.from_generator(generator=lambda: validation_generator(file_paths_val, file_paths_val_GT, file_paths_grappa_indx_val, file_paths_grappa_wt, file_paths_grappa_p), output_shapes=(((None, None, None, None), (None,)), (None, None, None)), output_types=((tf.float32, tf.int64), tf.float32))


## Add pre-fetch

training_dataset = training_dataset.prefetch(buffer_size=batch_size)
validation_dataset = validation_dataset.prefetch(buffer_size=batch_size)


print('Done. Building the GrappaNet model architecture...')


# Build the model
# ...

## Train the model

model_name = "/usr/local/micapollo01/MIC/DATA/STUDENTS/mvhave7/Results/Models/best_model_GrappaNet.h5"

def step_decay(epoch, initial_lrate, drop, epochs_drop):
    return initial_lrate * math.pow(drop, math.floor((1+epoch)/float(epochs_drop)))

class ClearMemory(Callback):
    def on_epoch_end(self, epoch, logs=None):
        gc.collect()
        k.clear_session()

def get_callbacks(model_file, learning_rate_drop=0.7, learning_rate_patience=7, verbosity=1):
    callbacks = list()
    callbacks.append(ModelCheckpoint(model_file, save_best_only=True))
    callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience, verbose=verbosity))
    callbacks.append(EarlyStopping(verbose=verbosity, patience=200))
    callbacks.append(ClearMemory())
    return callbacks

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    input_shape = (crop_size[1],crop_size[2],crop_size[0])
    epochs = 1  # In the original paper, 20 epochs were used
    batch_size = 8  # This batch size has unit 'number of slices', not 'number of files'
    model = build_model(input_shape)
    metrics = tf.keras.metrics.RootMeanSquaredError()
    model.compile(loss=model_loss_ssim, optimizer=Adam(learning_rate=0.0003), metrics=[metrics], run_eagerly=True)
    #model.compile(loss=model_loss_ssim, optimizer=RMSprop(learning_rate=0.0003), metrics=[metrics])


history = model.fit(training_dataset,
            epochs=epochs,
            batch_size=batch_size,
            shuffle=False,
            validation_data=validation_dataset,
            callbacks=get_callbacks(model_name,0.6,10,1),
            max_queue_size=32,
            workers=100,
            use_multiprocessing=False)