Split make_csv_dataset batches into a train and validation set?

Following up on my ‘Solved: Abalone Shell Load CSV batch input from dataset’ topic:

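abalone_csv_ds = tf.data.experimental.make_csv_dataset(
    abalone_csv_path,  # placeholder name: the file argument is missing from the post
    column_names=["Length", "Diameter", "Height", "Whole weight", "Shucked weight",
           "Viscera weight", "Shell weight", "Age"],
    batch_size=10, # Artificially small to make examples easier to show.
    label_name="Age",  # assumed: pack() below expects (features, label) pairs
    num_epochs=1)  # assumed: by default the dataset repeats indefinitely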

def pack(features, label):
  return tf.stack(list(features.values()), axis=-1), label

packed_dataset = abalone_csv_ds.map(pack)

model.fit(packed_dataset, ...)

make_csv_dataset does not seem to support a split option, so I guess packed_dataset is not split automatically into a training and test set. Any suggestions on how to split the CSV batches into a train and validation set?


Check out tf.data.Dataset.take.


Thank you. take(1) seems to return one batch of 10 records (or 100 or 1000, whatever you specify for batch_size), so the following snippet that I found does not work, because DATASET_SIZE is equal to 1:

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer size; a fixed order keeps the subsets disjoint.
full_dataset = full_dataset.shuffle(DATASET_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)

According to the documentation take() returns a dataset, but take(1) does not have a split() method. Any suggestions?


I see 2 ways you can do your split.

  1. Do the splitting before you batch your dataset (as in the example code).
  2. Use the take() and skip() methods to select the batches you need for training, validation and testing. This works for both batched and unbatched data.

The example code you provided works best on a dataset that has not been batched yet, which is why it does not work with your dataset in its current, already batched state.
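A minimal sketch of the first option, under some assumptions: the in-memory dataset below is only a stand-in for packed_dataset, and the names rows and dataset_size are made up for illustration. It undoes the batching with unbatch(), splits by row, and batches again. It also assumes the source yields rows in a stable order; with make_csv_dataset that probably means passing shuffle=False, since the default shuffling would otherwise reorder rows on every pass.

```python
import tensorflow as tf

# Stand-in for packed_dataset: 100 rows of 7 features, in batches of 10.
# The labels are row ids so the split can be inspected afterwards.
features = tf.random.uniform((100, 7))
labels = tf.range(100)
packed_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)

# Undo the batching so that take() and skip() count rows, not batches.
rows = packed_dataset.unbatch()

# The size is unknown after unbatch(), so count the rows with one pass.
dataset_size = sum(1 for _ in rows)
train_size = int(0.8 * dataset_size)

# Shuffle once in a fixed order so the two subsets cannot overlap.
rows = rows.shuffle(dataset_size, seed=42, reshuffle_each_iteration=False)

# First 80% for training, the remainder for validation, then re-batch.
train_ds = rows.take(train_size).batch(10)
val_ds = rows.skip(train_size).batch(10)
```

Counting rows costs one extra pass over the data; if the number of records is already known, that step can be replaced by a constant.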


When applied to a batched dataset, take() and skip() count whole batches rather than individual records. For an 80/20 split you would have to repeatedly take(8), skip(8) and take(2) batches, which becomes an issue near the end of the dataset, as there might not be enough batches left.
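One possible way around that bookkeeping, offered as a sketch rather than something suggested in this thread: number the batches with enumerate() and send every fifth batch to validation with filter(). The helper names (is_training, is_validation, drop_index) and the in-memory stand-in data are assumptions for illustration. The total number of batches does not need to be known, although the split is only approximately 80/20 when the batch count is not a multiple of five.

```python
import tensorflow as tf

# Stand-in for a batched dataset: 95 rows give 9 full batches plus one of 5.
features = tf.random.uniform((95, 7))
labels = tf.range(95)
batched = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)

def is_validation(index, batch):
    # Batches 4, 9, 14, ... go to the validation set.
    return tf.equal(index % 5, 4)

def is_training(index, batch):
    # All other batches go to the training set.
    return tf.not_equal(index % 5, 4)

def drop_index(index, batch):
    # Remove the counter added by enumerate(), keeping (features, labels).
    return batch

numbered = batched.enumerate()
train_ds = numbered.filter(is_training).map(drop_index)
val_ds = numbered.filter(is_validation).map(drop_index)
```

As with take() and skip(), this only gives disjoint subsets if the underlying dataset yields its batches in the same order on every pass.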

Continuing from the packed_dataset above, the following code trains the model one batch at a time, splitting each batch into a training and a validation set using train_test_split from scikit-learn:

from sklearn.model_selection import train_test_split

for epoch in range(200):
    print('epoch=', epoch)
    for features, labels in packed_dataset:
        # Split the current batch 80/20 into training and validation rows.
        x_train, x_test, y_train, y_test = train_test_split(
            features.numpy(), labels.numpy(), test_size=0.2)
        model.fit(x_train, y_train,
                  validation_data=(x_test, y_test))
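For comparison, once separate training and validation datasets exist (however they were produced), Keras can consume both directly, because model.fit accepts a tf.data.Dataset for validation_data. The sketch below is illustrative only: the tiny model, the random stand-in data and the names train_ds and val_ds are assumptions, not the model or data from this thread.

```python
import tensorflow as tf

# Random stand-in data: 80 training rows and 20 validation rows, 7 features.
features = tf.random.uniform((100, 7))
labels = tf.random.uniform((100, 1))
train_ds = tf.data.Dataset.from_tensor_slices(
    (features[:80], labels[:80])).batch(10)
val_ds = tf.data.Dataset.from_tensor_slices(
    (features[80:], labels[80:])).batch(10)

# A deliberately small regression model, just enough to call fit().
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# One fit() call handles all epochs and reports val_loss after each one.
history = model.fit(train_ds, validation_data=val_ds, epochs=2, verbose=0)
```

Compared with calling fit() once per batch, this evaluates the validation set once per epoch, and the validation rows are never used for training.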