TensorFlow Dataset reduce function too slow after skip

Hi everyone,
I’m trying to take 100 random elements from the CIFAR-10 dataset for each class and reduce them to a single “image” by taking the mean.
The problem is that the reduce step becomes much slower once I call the skip function with a large value.

The code is the following:
import tensorflow as tf

for c in classes:
    skip_elem = 100
    # keep only the batches whose (first) label is class c, skip some, take 100, then unbatch into single (image, label) pairs
    class_ds = train_ds.filter(lambda x, y: tf.equal(tf.argmax(y[0]), int(c))).skip(skip_elem).take(100).unbatch()
    # accumulate the sum of the images starting from a zero image, then divide to get the mean
    z = class_ds.reduce(tf.zeros(shape=(res // 8, res // 8, 4), dtype=tf.float32), lambda a, b: a + b[0])
    z /= 100
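
To give a sense of the slowdown, this is roughly how I’m comparing the two cases (the time_reduce helper and the skip values below are only for illustration, they are not part of my real code):

import time
import tensorflow as tf

def time_reduce(skip_elem, c=0):
    # build the same per-class pipeline as above and time only the reduce call
    class_ds = train_ds.filter(lambda x, y: tf.equal(tf.argmax(y[0]), int(c))).skip(skip_elem).take(100).unbatch()
    start = time.perf_counter()
    class_ds.reduce(tf.zeros(shape=(res // 8, res // 8, 4), dtype=tf.float32), lambda a, b: a + b[0])
    return time.perf_counter() - start

print(time_reduce(10))     # small skip: finishes quickly
print(time_reduce(5000))   # large skip: takes much longer for me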

I’ve tried batching and rebatching, but it seems that if I skip only a few elements the reduction finishes in a reasonable time. The “take” function also increases the reduce time, but that I expected.
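
One of the rebatch attempts looked roughly like this (a sketch of the idea; batch(100) is just an illustrative choice):

class_ds = (train_ds
            .filter(lambda x, y: tf.equal(tf.argmax(y[0]), int(c)))
            .skip(skip_elem)
            .take(100)
            .unbatch()
            .batch(100))
# reduce over whole batches instead of single images
z = class_ds.reduce(tf.zeros(shape=(res // 8, res // 8, 4), dtype=tf.float32),
                    lambda a, b: a + tf.reduce_sum(b[0], axis=0))
z /= 100

In both variants the slowdown seems tied to the skip value rather than to how the reduction itself is batched.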

Why does this happen? Are there other ways to do this?

Thanks!