Parallel data extraction with tf.data.Dataset.from_generator

I have a huge dataset (1 TB) made up of thousands of small HDF5 files, each containing two 3D NumPy arrays (float64 only). The data is currently fetched by a generator that is passed to tf.data.Dataset.from_generator. Since I can't cache my data, the data-fetching process is quite slow. Now I want to use all my CPUs and fetch from my dataset in parallel. Here is my code:

import h5py
import tensorflow as tf

def generator(files):
    for file in files:
        with h5py.File(file, 'r') as hf:
            epsilon = hf['epsilon'][()]
            field = hf['field'][()]
        yield epsilon, field

dataset = tf.data.Dataset.from_generator(
    generator, args=[files],
    output_signature=(tf.TensorSpec(shape=s[0], dtype=tf.float64),   # s[0], s[1]: shapes of the
                      tf.TensorSpec(shape=s[1], dtype=tf.float64)))  # epsilon and field arrays

Is there a best-practice solution to this problem?

Hi @munsteraner, you can consider using tf.data.Dataset.prefetch, which allows later elements to be prepared while the current element is being processed. Thank you.
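For illustration, a minimal standalone sketch of where prefetch usually sits in a pipeline; the range dataset and the map step are placeholders standing in for the slow per-file HDF5 read, not code from this thread:

import tensorflow as tf

# Toy pipeline, purely illustrative: the map step stands in for the
# (slow) per-file read, and prefetch overlaps preparation of the next
# elements with consumption of the current batch.
dataset = tf.data.Dataset.range(1_000)
dataset = dataset.map(lambda x: tf.cast(x, tf.float64) * 2.0,
                      num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)  # AUTOTUNE picks the buffer size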

Yeah, I am also using prefetch, but that doesn't help either. I guess my disk reading speed is limiting the extraction. I have also read that tf.py_function only uses one core, and I don't know how to run it in multiple threads. Here is my code (it fetches data, but there is no speedup):

import glob
import h5py
import numpy as np
import tensorflow as tf


def load_hdf5_file(filename):
    # tf.py_function passes the filename as a string tensor; convert it back to a Python string
    filename = tf.strings.reduce_join(filename, separator="").numpy().decode("utf-8")
    with h5py.File(filename, 'r') as hf:
        epsilon = hf['epsilon'][()]
        field = hf['field'][()]
    return epsilon, field

filenames = glob.glob(f'{args.p_data}/*.h5')

# Create a tf.data.Dataset from the list of filenames
dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Map load_hdf5_file over the filenames to load both arrays from each HDF5 file
num_parallel_calls = tf.data.AUTOTUNE
dataset = dataset.map(
    lambda x: tf.py_function(load_hdf5_file, [x], Tout=(tf.float64, tf.float64)),
    num_parallel_calls=num_parallel_calls)

spe = int(np.floor(len(filenames) / 32))
dataset = dataset.take(len(filenames)).batch(32).cache().repeat(3).prefetch(num_parallel_calls)

m = var_unet_3D_test.build_unet((120, 100, 50, 1))
m.compile(run_eagerly=True)
m.fit(dataset, epochs=args.bs, steps_per_epoch=spe)
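As a hedged sketch (not from the thread): one common way to get file-level parallelism with tf.data is to interleave per-file datasets. The helper name one_file_dataset, the cycle_length of 8, and the 'data/*.h5' path are illustrative choices, and Python-side reads can still be throttled by the GIL, h5py's internal lock, and disk throughput, so this is a starting point to profile rather than a guaranteed fix:

import glob
import h5py
import tensorflow as tf

def load_hdf5_file(filename):
    # Runs in Python via tf.py_function; filename arrives as a string tensor.
    filename = filename.numpy().decode("utf-8")
    with h5py.File(filename, 'r') as hf:
        return hf['epsilon'][()], hf['field'][()]

def one_file_dataset(filename):
    # Each file becomes a tiny dataset with a single (epsilon, field) element.
    return tf.data.Dataset.from_tensors(filename).map(
        lambda f: tf.py_function(load_hdf5_file, [f], Tout=(tf.float64, tf.float64)))

filenames = glob.glob('data/*.h5')  # adjust to your data path
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
           .interleave(one_file_dataset,
                       cycle_length=8,                      # files opened concurrently (illustrative)
                       num_parallel_calls=tf.data.AUTOTUNE,
                       deterministic=False)                 # allow out-of-order results for speed
           .batch(32)                                       # note: shapes are dynamic after py_function;
                                                            # restore them with tf.ensure_shape if Keras needs them
           .prefetch(tf.data.AUTOTUNE))

deterministic=False trades element order for throughput, which is usually acceptable for training data.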