Input ran out of data interrupting training

I am working on a tumor segmentation task with 3D MRI images in NIfTI format. I created a data generator because loading the full dataset into RAM crashes the session due to the size of the 3D volumes. The dataset consists of 611 images, each of shape (240, 240, 160); when I calculated the number of patches, I got 61,000 for training and 11,375 for validation.

Here is my data pipeline:


import os

import nibabel as nib
import numpy as np
import tensorflow as tf
from patchify import patchify

def load_nifti_image(filepath, patch_size=(48, 48, 32), step_size=(48, 48, 32)):
    nifti = nib.load(filepath)
    volume = nifti.get_fdata()

    # Create patches from the volume
    patches = patchify(volume, patch_size, step=step_size)

    # Flatten the (5, 5, 5) grid of patches into a single axis and add a channel dimension (1 for grayscale)
    patches = patches.reshape(-1, *patches.shape[-3:])
    patches = np.expand_dims(patches, axis=-1)

    return patches
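
For reference, the shapes this produces can be checked without any NIfTI I/O. A minimal sketch with a synthetic volume (the zero-filled array is just a placeholder):

volume = np.zeros((240, 240, 160))
patches = patchify(volume, (48, 48, 32), step=(48, 48, 32))
print(patches.shape)   # (5, 5, 5, 48, 48, 32) -- a 5x5x5 grid of patches
patches = patches.reshape(-1, *patches.shape[-3:])
patches = np.expand_dims(patches, axis=-1)
print(patches.shape)   # (125, 48, 48, 32, 1) -- 125 patches per volume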

# -----------------------TRAIN-----------------------
train_img_dir = "/content/drive/MyDrive/Interpolated/train/images"
train_mask_dir = "/content/drive/MyDrive/Interpolated/train/masks"
# sorted() keeps image/mask pairs aligned; os.listdir returns files in arbitrary order
nifti_files = [os.path.join(train_img_dir, f) for f in sorted(os.listdir(train_img_dir)) if f.endswith('.nii.gz')]
mask_files = [os.path.join(train_mask_dir, f) for f in sorted(os.listdir(train_mask_dir)) if f.endswith('.nii.gz')]

# -----------------------VALIDATION-----------------------
val_img_dir = "/content/drive/MyDrive/Interpolated/validation/images"
val_mask_dir = "/content/drive/MyDrive/Interpolated/validation/masks"
nifti_files_val = [os.path.join(val_img_dir, f) for f in sorted(os.listdir(val_img_dir)) if f.endswith('.nii.gz')]
mask_files_val = [os.path.join(val_mask_dir, f) for f in sorted(os.listdir(val_mask_dir)) if f.endswith('.nii.gz')]

def calculate_patches(filepath, patch_size=(48, 48, 32), step_size=(48, 48, 32)):
    nifti = nib.load(filepath)
    volume = nifti.get_fdata()

    # Calculate the number of patches
    patches_shape = [((i - p) // s) + 1 for i, p, s in zip(volume.shape, patch_size, step_size)]
    num_patches = np.prod(patches_shape)

    return num_patches
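
As a sanity check, the per-volume count can be computed directly from the shapes given above, without loading any file:

patches_shape = [((i - p) // s) + 1
                 for i, p, s in zip((240, 240, 160), (48, 48, 32), (48, 48, 32))]
print(patches_shape, np.prod(patches_shape))  # [5, 5, 5] 125
# 61000 training patches // batch size 32 = 1906 steps, matching the epoch length below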

num_train_patches = sum(calculate_patches(f) for f in nifti_files)
num_val_patches = sum(calculate_patches(f) for f in nifti_files_val)

def data_generator(image_files, mask_files):
    for img_file, mask_file in zip(image_files, mask_files):
        image_patches = load_nifti_image(img_file)
        mask_patches = load_nifti_image(mask_file)

        for img_patch, mask_patch in zip(image_patches, mask_patches):
            yield img_patch, mask_patch
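
A quick way to check what the generator yields is a one-off peek (this builds a throwaway generator and consumes a single element; it does not touch the pipeline below):

sample_img, sample_mask = next(data_generator(nifti_files, mask_files))
print(sample_img.shape, sample_mask.shape)  # (48, 48, 32, 1) (48, 48, 32, 1)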

train_generator = data_generator(nifti_files, mask_files)
val_generator = data_generator(nifti_files_val, mask_files_val)

output_signature = (
    tf.TensorSpec(shape=(48, 48, 32, 1), dtype=tf.float64),
    tf.TensorSpec(shape=(48, 48, 32, 1), dtype=tf.float64)
)

dataset = tf.data.Dataset.from_generator(lambda: train_generator, output_signature=output_signature).repeat()
dataset_val = tf.data.Dataset.from_generator(lambda: val_generator, output_signature=output_signature).repeat()

dataset = dataset.batch(32)
dataset_val = dataset_val.batch(32)

test_model.fit(dataset,
               validation_data=dataset_val,
               epochs=100,
               steps_per_epoch=num_train_patches // 32,
               validation_steps=num_val_patches // 32)
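
For context on how tf.data consumes Python generators, here is a minimal toy sketch (make_gen is a hypothetical stand-in, not part of the pipeline above). from_generator invokes the supplied callable every time the dataset is re-iterated, so a callable that returns an already-created generator object hands back an exhausted iterator on the second pass:

def make_gen():
    for i in range(3):
        yield i

spec = tf.TensorSpec(shape=(), dtype=tf.int32)

# Callable returns the same pre-built generator object: it is consumed
# on the first pass, so later iterations (and .repeat()) find it empty.
gen = make_gen()
ds_once = tf.data.Dataset.from_generator(lambda: gen, output_signature=spec)
print(list(ds_once.as_numpy_iterator()))   # [0, 1, 2]
print(list(ds_once.as_numpy_iterator()))   # []

# Callable builds a fresh generator on every iteration.
ds_fresh = tf.data.Dataset.from_generator(make_gen, output_signature=spec)
print(list(ds_fresh.as_numpy_iterator()))  # [0, 1, 2]
print(list(ds_fresh.as_numpy_iterator()))  # [0, 1, 2]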

I added the .repeat() call hoping it would help, but it did not. I also tried calculating the number of patches and passing steps_per_epoch and validation_steps manually, but that did not work either. This is the complete output:

Epoch 1/100
1906/1906 [==============================] - 1963s 1s/step - loss: 0.6447 - dice_coefficient: 0.3553 - val_loss: 0.9113 - val_dice_coefficient: 0.0887
Epoch 2/100
   1/1906 [..............................] - ETA: 10:17 - loss: 1.0000 - dice_coefficient: 1.7961e-05

WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 190600 batches). You may need to use the repeat() function when building your dataset.
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 355 batches). You may need to use the repeat() function when building your dataset.

1906/1906 [==============================] - 0s 31us/step - loss: 1.0000 - dice_coefficient: 1.7961e-05

<keras.src.callbacks.History at 0x7a793461a590>

Hi @matca, if possible, could you please share sample data to reproduce the warnings? Thank you.