How to efficiently filter a specific number of entries and concatenate them into a single tf.data.Dataset

I have a huge TFRecord file with more than 4M entries. It is a very unbalanced dataset: some labels have many more entries than others, relative to the whole dataset. I want to filter a limited number of entries for some of these labels in order to have a balanced dataset. Below you can see my attempt, but it takes more than 24 hours to filter 1k entries from each label (33 different labels).

import tensorflow as tf

# Use a TPU when one can be reached; otherwise fall back to the default
# distribution strategy. `except Exception` (rather than a bare `except`)
# keeps the best-effort fallback but lets KeyboardInterrupt / SystemExit
# propagate, and the reason for the fallback is reported instead of hidden.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as err:
    print("TPU not available, using the default strategy:", err)
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

# Allow tf.data to yield elements out of order.
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)
# `detect_schema` is defined outside this excerpt; its two results are used
# as the context and sequence feature specs when parsing each record.
features, feature_lists = detect_schema(dataset)

# Decode one serialized TFRecord entry
def decode_data(serialized):
    """Parse a serialized SequenceExample into a (title, subject) pair.

    The record is parsed with the module-level `features` (context schema)
    and `feature_lists` (sequence schema).

    Args:
        serialized: One serialized SequenceExample record.

    Returns:
        A tuple of the context feature 'title' and the sequence feature
        'subject'.
    """
    context, sequences = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists,
    )
    title = context['title']
    subject = sequences['subject']
    return title, subject

# decode_data only calls a regular TensorFlow parsing op, so it can be mapped
# directly and traced into the tf.data graph. Wrapping it in tf.py_function
# would send every record through the Python interpreter and force the output
# dtypes to the hard-coded `Tout` instead of the ones given by the schema.
# num_parallel_calls lets tf.data parse several records concurrently.
dataset = dataset.map(decode_data, num_parallel_calls=tf.data.AUTOTUNE)

#Filtering and concatenating the samples

def balanced_dataset(dataset, labels_list, sample_size=1000):
    """Build a dataset with at most `sample_size` elements for each label.

    For every label in `labels_list`, `dataset` is filtered down to the
    elements whose second component contains that label and the first
    `sample_size` matches are kept. The per-label subsets are then
    concatenated in the order the labels were given.

    Args:
        dataset: Dataset of `(x, y)` pairs providing `filter`, `take` and
            `concatenate`; `y` is compared against an int64 constant.
        labels_list: Non-empty iterable of label values convertible to
            `tf.int64`.
        sample_size: Maximum number of elements kept per label.

    Returns:
        The concatenation of the per-label subsets.

    Raises:
        IndexError: If `labels_list` is empty.
    """

    def contains_label(label):
        """Return a predicate that is true when `y` contains `label`."""
        return lambda x, y: tf.reduce_any(
            tf.equal(tf.constant(label, dtype=tf.int64), y))

    # Keep the subsets in an ordinary list. Writing through `locals()` inside
    # a function is not reliable: on Python 3.13+ every `locals()` call
    # returns a fresh snapshot, so a value stored that way cannot be read back.
    subsets = [
        dataset.filter(contains_label(label)).take(sample_size)
        for label in labels_list
    ]

    # NOTE: each subset filters `dataset` independently, so iterating the
    # result scans the source once per label until that label's quota is met.
    balanced = subsets[0]
    for subset in subsets[1:]:
        balanced = balanced.concatenate(subset)
    return balanced

# Sample up to 1000 elements for every label value stored in `decod_dic`.
# NOTE(review): `tabledataset` and `decod_dic` are not defined in this
# excerpt; confirm they refer to the decoded dataset and the label mapping.
balanced_data = balanced_dataset(
    tabledataset,
    labels_list=list(decod_dic.values()),
    sample_size=1000,
)