Validation set is not cached

I have a large training dataset of 400 GB, a validation set of 50 GB, and 950 GB of RAM. Since these datasets are too large to load at once, I'm using tf.data.Dataset.from_generator in TensorFlow to fetch them through a generator. To improve performance, I'm caching the datasets with tf.data.Dataset.cache(); without caching, training takes far too long to be feasible.

While caching works well for the training dataset, I'm having trouble with the validation set: in every epoch it is read from disk again instead of being served from the cache.

import h5py

def generator(files, m):
    # Yield one (epsilon, field) pair per HDF5 file, with field min-max scaled to [-1, 1].
    for file in files:
        with h5py.File(file, 'r') as hf:
            epsilon = hf['epsilon'][()]
            field = hf['field'][()]
            field = [scaler(x, m) for x in field]
        yield epsilon, field

def scaler(x, m):
    # Min-max scale x into [-1, 1] using m = (min, max).
    return (2 * ((x - m[0]) / (m[1] - m[0]))) - 1

spe = int(np.floor(len(files) / args.bs))
vspe = int(np.floor(len(val_files) / args.bs))

dataset = tf.data.Dataset.from_generator(
    pygen.generator, args=[files, minmax],
    output_signature=(
        tf.TensorSpec(shape=s[0], dtype=tf.float32),
        tf.TensorSpec(shape=s[1], dtype=tf.float32)))

val_dataset = tf.data.Dataset.from_generator(
    pygen.generator, args=[val_files, minmax],
    output_signature=(
        tf.TensorSpec(shape=s[0], dtype=tf.float32),
        tf.TensorSpec(shape=s[1], dtype=tf.float32)))

dataset = dataset.take(len(files)).cache().batch(args.bs).repeat(args.ep)
val_dataset = val_dataset.take(len(val_files)).cache().batch(args.bs).repeat(args.ep)

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    m = unet65.build(s[0])
    m.fit(dataset,
          validation_data=val_dataset,
          epochs=args.ep,
          steps_per_epoch=spe,
          validation_steps=vspe,
          callbacks=[model_checkpoint_callback, model_csv_logger,
                     model_tensorboard, model_earlystopping_unet65])

Does anyone know how to speed up fetching of the validation set? I also tried parallel fetching with num_parallel_calls=tf.data.AUTOTUNE and prefetch(), but nothing is fast enough.

Hi @munsteraner,

To further improve the speed of validation set fetching, you can try the following approaches:

1. Preload the entire validation set into memory. With 950 GB of RAM available, the 50 GB validation set can stay resident for the whole run (see the first sketch below).
2. Convert the validation data to the TFRecord format. If your validation dataset is composed of many individual files, converting it to TFRecords gives you a format optimized for reading large datasets efficiently (see the second sketch below).
3. Cache the validation dataset to a file on disk. Passing a filename to cache() writes the cache to disk, so the validation dataset can be reused efficiently in subsequent epochs even when it does not fit in RAM.
4. Use a larger batch size for validation. A larger batch size reduces the number of disk I/O operations per validation pass.
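Here is a rough sketch of points 1 and 3, reusing pygen.generator, val_files, minmax, s, and args from your snippet (those names are assumed to exist exactly as in your code). The idea is to cache the validation elements once, either in RAM or in a disk cache file, and to warm the cache before training starts so that every epoch reads from the cache instead of re-running the HDF5 generator:

import tensorflow as tf

val_dataset = tf.data.Dataset.from_generator(
    pygen.generator, args=[val_files, minmax],
    output_signature=(
        tf.TensorSpec(shape=s[0], dtype=tf.float32),
        tf.TensorSpec(shape=s[1], dtype=tf.float32)))

# cache() with no argument keeps everything in RAM (50 GB fits comfortably here);
# cache('/some/fast/disk/val_cache') would keep the cache on disk instead.
val_dataset = val_dataset.cache()

# One complete pass finalizes the cache, so even the first epoch validates from it.
for _ in val_dataset:
    pass

val_dataset = (val_dataset
               .batch(args.bs)
               .repeat()                     # repeat indefinitely; validation_steps=vspe bounds each epoch
               .prefetch(tf.data.AUTOTUNE))  # overlap fetching with validation compute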
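For point 2, a one-off conversion to a single TFRecord file could look roughly like the sketch below. The output path 'val.tfrecord' is just a placeholder, and the vectorized scaling line reproduces the element-wise scaler from your generator:

import h5py
import tensorflow as tf

def to_tfrecord(files, out_path, m):
    # Write one serialized (epsilon, field) example per HDF5 file.
    with tf.io.TFRecordWriter(out_path) as writer:
        for file in files:
            with h5py.File(file, 'r') as hf:
                epsilon = hf['epsilon'][()].astype('float32')
                field = hf['field'][()].astype('float32')
                field = (2 * ((field - m[0]) / (m[1] - m[0]))) - 1  # same min-max scaling
            example = tf.train.Example(features=tf.train.Features(feature={
                'epsilon': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(epsilon).numpy()])),
                'field': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[tf.io.serialize_tensor(field).numpy()])),
            }))
            writer.write(example.SerializeToString())

def parse(record):
    # Parse a serialized example back into (epsilon, field) tensors.
    feats = tf.io.parse_single_example(record, {
        'epsilon': tf.io.FixedLenFeature([], tf.string),
        'field': tf.io.FixedLenFeature([], tf.string),
    })
    epsilon = tf.io.parse_tensor(feats['epsilon'], tf.float32)
    field = tf.io.parse_tensor(feats['field'], tf.float32)
    epsilon.set_shape(s[0])  # shapes from the output_signature in your snippet
    field.set_shape(s[1])
    return epsilon, field

to_tfrecord(val_files, 'val.tfrecord', minmax)
val_dataset = (tf.data.TFRecordDataset('val.tfrecord')
               .map(parse, num_parallel_calls=tf.data.AUTOTUNE)
               .batch(args.bs)
               .prefetch(tf.data.AUTOTUNE))

Since the conversion is done once offline, the per-epoch cost during training is only the sequential TFRecord read plus the cheap parse step.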

I hope this helps!

Thanks.