tf.data.Dataset order varies on re-iteration. Is a manual reset possible?

I wonder why repeatedly iterating over a tf.data.Dataset does not yield the same order each time (see the following code). Is there a way to reset the dataset manually, like it is reset for each training epoch?

This would be useful for manual model evaluation: I could reuse my val_dataset to call model.predict(val_dataset), then compare the predictions with the true classes in y. This way I could e.g. compute a Balanced Accuracy.
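One way to avoid the ordering question entirely is to compute the predictions and collect the labels in the same pass over the dataset. That way, each prediction stays paired with the label from its own batch. The sketch below assumes a binary sigmoid output thresholded at 0.5. `collect_predictions` and `balanced_accuracy` are hypothetical helper names, and `predict_fn` stands in for something like `lambda x: model(x, training=False).numpy()` with your own `model`. The toy batches at the end only demonstrate the call.

```python
import numpy as np


def collect_predictions(predict_fn, dataset):
    """Run predict_fn on every (x, y) batch of ONE pass over `dataset`.

    Predictions and labels come from the same iteration, so they stay
    aligned even if the dataset reshuffles on every pass.
    """
    preds, labels = [], []
    for x, y in dataset:
        preds.append(np.asarray(predict_fn(x)).reshape(-1))
        labels.append(np.asarray(y).reshape(-1))
    return np.concatenate(preds), np.concatenate(labels)


def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall, over the classes present in y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


# Hypothetical stand-in data: two batches of (features, binary labels).
fake_batches = [
    (np.array([[0.1], [0.9]]), np.array([[0.0], [1.0]])),
    (np.array([[0.8], [0.2]]), np.array([[0.0], [0.0]])),
]
scores, y = collect_predictions(lambda x: x, fake_batches)  # identity "model"
y_hat = (scores > 0.5).astype(y.dtype)                      # threshold at 0.5
print(balanced_accuracy(y, y_hat))
```

With a real dataset, you would pass val_dataset in place of fake_batches. If scikit-learn is available, sklearn.metrics.balanced_accuracy_score computes the same metric.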

For testing, I have created a synthetic dataset:

import os
import numpy as np
import PIL.Image  # The submodule must be imported explicitly for PIL.Image.fromarray below.

datapath = 'data-synthetic'
image_size = 224
for i_class in range(2):
    classpath = os.path.join(datapath, f'class_{i_class}')
    os.makedirs(classpath)
    for i_img in range(160): # Align with default batch size of 32.
        pixels = np.random.randint(i_class * 128, (i_class+1) * 128, (image_size, image_size))
        image = PIL.Image.fromarray(pixels.astype('uint8'), 'L').convert('RGB')
        image.save(os.path.join(classpath, f'{i_img:03d}.png'))

This dataset is then used in my test code:

import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory

assert tf.__version__ == '2.9.1', f'Currently: {tf.__version__}' # Remember to run pip-sync

image_size = 224
def get_dataset(subset):
    print('get_dataset:', subset)
    return image_dataset_from_directory(
        'data-synthetic',
        labels="inferred",
        label_mode='binary',
        color_mode="rgb",
        batch_size=32,
        image_size=(image_size, image_size),
        shuffle=True,
        seed=1,
        validation_split=0.1,
        subset=subset,
        crop_to_aspect_ratio=False,
    )

val_dataset = get_dataset('validation')
for x,y in val_dataset: # Iteration yields the single batch.
    print(tf.transpose(y))

print('The dataset is reproducible:')
val_dataset = get_dataset('validation')
for x,y in val_dataset:
    print(tf.transpose(y))

print('... but not when just re-iterating:')
for x,y in val_dataset:
    print(tf.transpose(y))

The output shows that repeated calls to image_dataset_from_directory yield the same element order, but simply re-iterating over the same dataset does not:

get_dataset: validation
Found 320 files belonging to 2 classes.
Using 32 files for validation.
tf.Tensor(
[[0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0.
  1. 1. 1. 0. 0. 1. 1. 0.]], shape=(1, 32), dtype=float32)
The dataset is reproducible:
get_dataset: validation
Found 320 files belonging to 2 classes.
Using 32 files for validation.
tf.Tensor(
[[0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0.
  1. 1. 1. 0. 0. 1. 1. 0.]], shape=(1, 32), dtype=float32)
... but not when just re-iterating:
tf.Tensor(
[[0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 1. 1. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0.
  1. 1. 0. 0. 1. 0. 1. 0.]], shape=(1, 32), dtype=float32)
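A hedged explanation, based on my reading of the Keras source around TF 2.9 rather than on documentation of this exact behavior: with shuffle=True, image_dataset_from_directory first shuffles the file list once with the given seed. That is why repeated calls are reproducible. It then also applies Dataset.shuffle(..., seed=seed) to the elements before batching. By default, Dataset.shuffle reshuffles on every iteration (reshuffle_each_iteration=True), even with a fixed seed. The toy sketch below shows this difference on Dataset.range instead of images.

```python
import tensorflow as tf

# Default reshuffle_each_iteration (True): the fixed seed does not stop
# each new pass over the dataset from getting a new order.
reshuffling = tf.data.Dataset.range(10).shuffle(10, seed=1)
first = list(reshuffling.as_numpy_iterator())
second = list(reshuffling.as_numpy_iterator())
print(first)
print(second)  # generally a different permutation than `first`

# reshuffle_each_iteration=False: the order is drawn once and reused on every pass.
frozen = tf.data.Dataset.range(10).shuffle(10, seed=1, reshuffle_each_iteration=False)
print(list(frozen.as_numpy_iterator()) == list(frozen.as_numpy_iterator()))  # True
```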

“shuffle” means “return in a randomized order”. Try making this False.

@Lance_N, thanks for the idea. However, I cannot set shuffle=False, because the validation split is not stratified. The result would be a useless validation set consisting only of class_1.
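The following sketch illustrates why shuffle=False fails here. As far as I know, image_dataset_from_directory takes the validation subset from the end of its file list, and without shuffling that list is ordered class by class. The plain-Python sketch below mimics this. It is a simplification, not the Keras implementation: it uses Python's random module instead of whatever Keras uses internally, and the file names are hypothetical.

```python
import random

# Hypothetical file list: 160 images per class, listed class by class,
# as in the synthetic dataset above.
files = [(f"class_{c}/{i:03d}.png", c) for c in range(2) for i in range(160)]
validation_split = 0.1
n_val = int(validation_split * len(files))  # 32 files

# Without shuffling, the validation subset (the tail of the list) is all class_1.
unshuffled_val = files[-n_val:]
print({label for _, label in unshuffled_val})  # {1}


def split_with_seed(seed):
    """Shuffle the file list once with a fixed seed, then take the tail."""
    shuffled = files.copy()
    random.Random(seed).shuffle(shuffled)
    return shuffled[-n_val:]


# Shuffling before splitting mixes the classes, and the same seed
# reproduces the same split on every call.
shuffled_val = split_with_seed(1)
print(sorted({label for _, label in shuffled_val}))
print(split_with_seed(1) == shuffled_val)  # True
```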

I don’t know whether you would find these two params of Dataset.shuffle useful: seed and reshuffle_each_iteration.

Thanks for pointing out these Dataset.shuffle arguments, @Bhack! I have tried to use them by returning image_dataset_from_directory(...).shuffle(buffer_size=320, seed=1, reshuffle_each_iteration=False) in my get_dataset function. Unfortunately, this does not make the re-iteration stable.
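The extra shuffle probably did not help for the following reason. image_dataset_from_directory returns an already batched dataset, so a .shuffle appended to it only reorders whole batches (here a single batch). Meanwhile, if I read the Keras source of that period correctly, the element-level Dataset.shuffle inside still reshuffles on every pass.

One workaround for versions before 2.10, which I have not tested on this exact pipeline, is to append .cache(). Once a first complete pass has been made, later passes are read from the in-memory cache and replay the order of that first pass. In the toy sketch below, the pipeline is an assumed stand-in for the real one.

```python
import tensorflow as tf

# Stand-in for a reshuffling pipeline: element-level shuffle, then batching.
# (Assumption: this mirrors what image_dataset_from_directory does when shuffle=True.)
pipeline = tf.data.Dataset.range(32).shuffle(256, seed=1).batch(32)

# cache() placed after the shuffle stores the first complete pass in memory;
# later passes replay that stored order instead of reshuffling.
stable = pipeline.cache()
pass_1 = [batch.tolist() for batch in stable.as_numpy_iterator()]
pass_2 = [batch.tolist() for batch in stable.as_numpy_iterator()]
print(pass_1 == pass_2)  # True
```

For the code above, this would be something like val_dataset = get_dataset('validation').cache(). Keep in mind that the cache holds the decoded images in memory, and only a complete pass fills it.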

I just noticed (and then also saw in a recent thread) that the upcoming TensorFlow 2.10.0 will have a tf.keras.utils.split_dataset function that I can use after calling image_dataset_from_directory with shuffle=False and without the validation_split argument, followed by your iteration-stable shuffling. This might do the job.

Until then, it seems I need another call to get_dataset('validation') each time I need another iteration.

I have now switched to tensorflow==2.10.0rc3 and split_dataset. Now even the original shuffling from image_dataset_from_directory is iteration-stable!

import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory

assert tf.__version__ == '2.10.0-rc3', f'Currently: {tf.__version__}' # Remember to run pip-sync

image_size = 224
def get_dataset():
    print('get_dataset')
    return image_dataset_from_directory(
        'data-synthetic',
        labels="inferred",
        label_mode='binary',
        color_mode="rgb",
        batch_size=32,
        image_size=(image_size, image_size),
        shuffle=True,
        seed=1,
        crop_to_aspect_ratio=False,
    )

_, val_dataset = tf.keras.utils.split_dataset(get_dataset(), right_size=0.1)

for x,y in val_dataset: # Iteration yields the single batch.
    print(tf.transpose(y))

print('The dataset is reproducible:')
_, val_dataset = tf.keras.utils.split_dataset(get_dataset(), right_size=0.1)
for x,y in val_dataset:
    print(tf.transpose(y))

print('... also when just re-iterating:')
for x,y in val_dataset:
    print(tf.transpose(y))

yields:

get_dataset
Found 320 files belonging to 2 classes.
tf.Tensor(
[[1. 1. 1. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0.
  0. 0. 1. 0. 0. 0. 1. 0.]], shape=(1, 32), dtype=float32)
The dataset is reproducible:
get_dataset
Found 320 files belonging to 2 classes.
tf.Tensor(
[[1. 1. 1. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0.
  0. 0. 1. 0. 0. 0. 1. 0.]], shape=(1, 32), dtype=float32)
... also when just re-iterating:
tf.Tensor(
[[1. 1. 1. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0.
  0. 0. 1. 0. 0. 0. 1. 0.]], shape=(1, 32), dtype=float32)

This is most important for me, because model.predict(val_dataset) now returns the predictions in the same order as the labels!
I have also found out that before TF 2.10.0, the misaligned predictions always matched the “second iteration” labels above, even when model.predict was called directly after image_dataset_from_directory.
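Before relying on model.predict(val_dataset) being aligned with the labels, a quick check is to iterate twice and compare the labels. The minimal sketch below does this. `labels_of` and `is_iteration_stable` are hypothetical helpers that accept any iterable of (x, y) batches, including a tf.data.Dataset.

```python
import numpy as np


def labels_of(dataset):
    """Concatenate the labels of one full pass over (x, y) batches."""
    return np.concatenate([np.asarray(y).reshape(-1) for _, y in dataset])


def is_iteration_stable(dataset):
    """True if two consecutive passes yield the labels in the same order."""
    return np.array_equal(labels_of(dataset), labels_of(dataset))


# Toy check with a plain list of batches, which is trivially stable.
toy = [(None, np.array([[0.0], [1.0]])), (None, np.array([[1.0], [0.0]]))]
print(is_iteration_stable(toy))  # True
```

Matching labels is a necessary but not sufficient condition: two passes could produce the same label sequence with different images. Comparing a checksum of x as well would be a stricter test.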