Batching by batch_id in a Dataset object

Hey! I have a tensorflow.python.data.ops.dataset_ops.MapDataset object that I load from TFRecord files.

It has three “columns”: encodings, label, batch_id.

I need to batch this dataset by the batch_id column.

For example, if I have 50 unique batch_ids, then after batching the dataset should contain 50 batches, and each batch should contain all the rows with that batch_id.

I’m trying to use the group_by_window function:

key_func = lambda elem: tf.cast(elem['batch_id'], tf.float32)
reduce_func = lambda key, window: window.batch(100000)
ds = ds.group_by_window(key_func=key_func, reduce_func=reduce_func, window_size=10000)

but this raises an error: TypeError: <lambda>() takes 1 positional argument but 3 were given
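The "3 were given" part suggests the parsed elements are 3-tuples (encodings, label, batch_id) rather than a dict: tf.data unpacks tuple elements into separate positional arguments before calling key_func, so a one-argument lambda fails. A second issue is that key_func must return a scalar tf.int64, not tf.float32. Below is a minimal sketch of a fix, assuming tuple elements; the small in-memory dataset and the MAX_ROWS_PER_ID value are hypothetical stand-ins for the real TFRecord data.

```python
import tensorflow as tf

# Hypothetical stand-in for the TFRecord-backed dataset. Assumption: each
# element is a tuple (encodings, label, batch_id), with batch_id an integer.
encodings = tf.random.uniform((12, 4))
labels = tf.range(12) % 2
batch_ids = tf.constant([0, 1, 2] * 4, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((encodings, labels, batch_ids))

# Assumed upper bound on how many rows share one batch_id. A group larger
# than window_size is emitted as more than one window, i.e. several batches.
MAX_ROWS_PER_ID = 100_000


def key_func(encoding, label, batch_id):
    # Tuple elements arrive as three positional arguments; the key must be
    # a scalar tf.int64.
    return tf.cast(batch_id, tf.int64)


def reduce_func(key, window):
    # `window` is a Dataset holding the rows of one group; batch it whole.
    return window.batch(MAX_ROWS_PER_ID)


grouped = ds.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=MAX_ROWS_PER_ID
)

for enc, lab, bid in grouped:
    # Groups are not guaranteed to come out in batch_id order.
    print(bid.numpy(), enc.shape)
```

If the parse function returns a dict instead, the single-argument form works, e.g. lambda elem: tf.cast(elem['batch_id'], tf.int64). Also note that group_by_window holds each group in memory until it reaches window_size or the input ends, so memory use grows with the largest group.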

How can I achieve this? Thanks!
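If the set of batch_id values is known in advance (like the 50 ids in the example), a simpler but slower alternative is one filter-and-batch pass per id, concatenated in a fixed order. This sketch assumes tuple elements (encodings, label, batch_id); batch_per_id and the toy data are hypothetical. Each id rereads the whole input, so 50 ids means 50 passes over the TFRecords unless the data is small enough to keep in memory with ds.cache().

```python
import tensorflow as tf


def batch_per_id(ds, ids, max_rows_per_id=100_000):
    """Emit one batch per id in `ids`, in that order (hypothetical helper).

    Assumes elements are tuples (encodings, label, batch_id), that `ids` is
    non-empty, and that no id has more than `max_rows_per_id` rows.
    """

    def make_predicate(target):
        # Build a separate predicate per id so each closes over its own target.
        def predicate(encoding, label, batch_id):
            return tf.equal(batch_id, target)
        return predicate

    result = None
    for target in ids:
        part = ds.filter(make_predicate(target)).batch(max_rows_per_id)
        result = part if result is None else result.concatenate(part)
    return result


# Toy stand-in for the real dataset.
encodings = tf.random.uniform((12, 4))
labels = tf.range(12) % 2
batch_ids = tf.constant([0, 1, 2] * 4, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((encodings, labels, batch_ids))

by_id = batch_per_id(ds, ids=[0, 1, 2])
```

Unlike group_by_window, this yields the batches in the order of ids, at the cost of the repeated passes.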