Batching by batch_id in a Dataset object

Hey! I have a tensorflow.python.data.ops.dataset_ops.MapDataset object that I load from TFRecord files.

It has three “columns”: encodings, label, batch_id.

I need to batch this dataset on the batch_id column.

For example, if I have 50 unique batch_ids, then after batching the dataset should contain 50 batches, each holding all the rows with the corresponding batch_id.

I’m trying to use the group_by_window function:

key_func = lambda elem: tf.cast(elem['batch_id'], tf.float32)
reduce_func = lambda key, window: window.batch(100000)
ds = ds.group_by_window(key_func=key_func, reduce_func=reduce_func, window_size=10000)

but this throws an error: TypeError: <lambda>() takes 1 positional argument but 3 were given

How can I achieve this? Thanks!
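One likely cause of this TypeError, assuming the parsing function passed to map() returns a tuple (encodings, label, batch_id) rather than a dict: tf.data unpacks tuple elements into separate positional arguments, so a one-argument lambda receives three. Separately, group_by_window expects key_func to return a scalar tf.int64, so a tf.float32 key would fail even once the arity is fixed. The sketch below uses a small in-memory dataset as a hypothetical stand-in for the parsed TFRecords.

```python
import tensorflow as tf

# Hypothetical stand-in for a TFRecord dataset whose map() function
# returns a tuple (encodings, label, batch_id) instead of a dict.
ds = tf.data.Dataset.from_tensor_slices((
    tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=tf.int32),
    tf.constant([0, 1, 0, 1], dtype=tf.int32),
    tf.constant([1, 2, 1, 2], dtype=tf.int32),
))

# Tuple elements are unpacked into positional arguments, so key_func takes
# three parameters. It must also return a scalar tf.int64 (not tf.float32).
def key_func(encodings, label, batch_id):
    return tf.cast(batch_id, tf.int64)

# reduce_func gets the key plus a dataset of the (still tuple-shaped)
# elements sharing that key; batch them into one element per window.
def reduce_func(key, window):
    return window.batch(10000)

grouped = ds.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=10000
)

# Collect encodings per batch_id so the grouping is easy to inspect.
result = {}
for encodings, labels, batch_ids in grouped:
    result[int(batch_ids[0])] = encodings.numpy().tolist()
print(result)
```

If the real dataset yields dicts instead, the single-argument lambda with elem['batch_id'] works as long as the key is cast to tf.int64.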

Hi @Ofek_Levy, as requested, I have implemented batching by batch_id on a sample dataset. Note that key_func must return a scalar tf.int64 tensor, so batch_id is cast to tf.int64 rather than tf.float32. Below is a working example:

import tensorflow as tf

data = [
    {"encodings": [0, 1, 2], "label": 0, "batch_id": 1},
    {"encodings": [3, 4, 5], "label": 1, "batch_id": 2},
    {"encodings": [6, 7, 8], "label": 0, "batch_id": 1},
    {"encodings": [9, 10, 11], "label": 1, "batch_id": 2},
]

dataset = tf.data.Dataset.from_generator(
    lambda: data,
    output_signature={
        "encodings": tf.TensorSpec(shape=(3,), dtype=tf.int32),
        "label": tf.TensorSpec(shape=(), dtype=tf.int32),
        "batch_id": tf.TensorSpec(shape=(), dtype=tf.int32),
    },
)

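# key_func must return a scalar tf.int64 tensor
key_func = lambda elem: tf.cast(elem['batch_id'], tf.int64)
# Each window holds up to window_size elements sharing one batch_id; batch them together
reduce_func = lambda key, window: window.batch(100000)
ds = dataset.group_by_window(key_func=key_func, reduce_func=reduce_func, window_size=10000)

for elem in ds:
    print(elem)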

# Output:

{'encodings': <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[0, 1, 2],
       [6, 7, 8]], dtype=int32)>, 'label': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 0], dtype=int32)>, 'batch_id': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>}
{'encodings': <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 3,  4,  5],
       [ 9, 10, 11]], dtype=int32)>, 'label': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>, 'batch_id': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>}
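One caveat worth checking against the "one batch per batch_id" requirement: group_by_window emits a window as soon as a key has collected window_size elements, so a batch_id with more rows than window_size is split across several batches. One batch per id only holds when window_size is at least as large as the biggest group, and in that case each group is buffered in memory until the end of the input. The sketch below uses small made-up data and a hypothetical helper, batch_sizes, to show the split.

```python
import tensorflow as tf

# Illustrative data: six rows with batch_id 1 and two rows with batch_id 2.
batch_ids = tf.constant([1, 1, 1, 1, 1, 1, 2, 2], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices({"batch_id": batch_ids})

def batch_sizes(window_size):
    """Hypothetical helper: returns sorted (batch_id, rows_in_batch) pairs."""
    grouped = ds.group_by_window(
        key_func=lambda elem: elem["batch_id"],  # already a scalar tf.int64
        reduce_func=lambda key, window: window.batch(window_size),
        window_size=window_size,
    )
    return sorted(
        (int(b["batch_id"][0]), int(b["batch_id"].shape[0])) for b in grouped
    )

print(batch_sizes(4))    # window smaller than the batch_id 1 group
print(batch_sizes(100))  # window larger than every group
```

With window_size=4, batch_id 1 comes out as a batch of 4 followed by a batch of 2; with window_size=100 it comes out as a single batch of 6.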

Thank you.