TensorFlow dataset: pick a sample of the whole data

I have code that generates an iterator from a TensorFlow dataset:

def normalize_image(record):
  out = record.copy()
  out['image'] = tf.cast(out['image'], 'float32') / 255.
  return out

train_it = iter(tfds.builder('mnist').as_dataset(split='train').map(normalize_image).repeat().batch(256*10))

However, I want to do the splitting manually. For example, the MNIST dataset has 60,000 training samples, but I want to use only the first 50,000 (and hold the rest out for validation). The problem is I don't know how to do this.

I tried converting it to NumPy and splitting there, but then I couldn't apply the map function to it:

ds_builder = tfds.builder('mnist')
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
train_ds['image'] = train_ds['image'][0:50000, : , :]  
train_ds['label'] = train_ds['label'][0:50000]
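One way to keep the map working is to stay inside tf.data and slice with take()/skip() before mapping, instead of dropping down to NumPy. A minimal sketch on a synthetic dataset (the dict-of-tensors structure mimics the MNIST records, so nothing needs to be downloaded; the same pattern applies unchanged to `tfds.builder('mnist').as_dataset(split='train')`):

```python
import tensorflow as tf

# Synthetic stand-in for the train split: 10 records with an 'image' field.
images = tf.reshape(tf.cast(tf.range(10 * 4), tf.uint8), [10, 2, 2])
ds = tf.data.Dataset.from_tensor_slices({'image': images})

def normalize_image(record):
    out = dict(record)
    out['image'] = tf.cast(out['image'], 'float32') / 255.
    return out

# take(n) keeps the first n elements in their original order; skip(n)
# drops the first n, so the two datasets partition the split.
train_ds = ds.take(8).map(normalize_image)
val_ds = ds.skip(8).map(normalize_image)

print(sum(1 for _ in train_ds))  # 8
print(sum(1 for _ in val_ds))    # 2
```

Because take() and skip() operate on the dataset in its iteration order, this also preserves the ordering the P.S. asks about (as long as file shuffling is not enabled when the dataset is built).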

I was wondering how to do so.

P.S.: The ordering of the data is also important to me. I was thinking of loading all the data into NumPy, saving the required samples as PNG files, and reloading them with tfds, but I'm not sure whether that keeps the original order. I want the first 50,000 of the whole 60,000 samples.


Hi @Hossein_Arjomandi, you can get the first 50,000 examples from the MNIST dataset with TFDS split slicing:

image, label = tfds.as_numpy(tfds.load(
    'mnist', split='train[:50000]', batch_size=-1, as_supervised=True))

Thank You.