Hi everyone,

I’m trying to take 100 random elements of each class from the CIFAR-10 dataset and reduce them to a single “image” by taking their mean.

The problem is that the reduce() call takes significantly longer when I use skip() with a large value.

The code is the following:

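import tensorflow as tf

for c in classes:
    skip_elem = 100  # matching elements to skip before taking 100

    class_ds = (train_ds
                .filter(lambda x, y: tf.equal(tf.argmax(y[0]), int(c)))  # keep only class c
                .skip(skip_elem)
                .take(100)
                .unbatch())

    z = class_ds.reduce(tf.zeros(shape=(res//8, res//8, 4), dtype=tf.float32),
                        lambda a, b: a + b[0])  # b is (x, y): sum the images only

    z /= 100  # mean over the 100 elements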
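A self-contained version of the same pipeline may help anyone trying to reproduce this without CIFAR-10. It is a sketch under assumptions: the dataset is replaced by hypothetical random data with one-hot labels, batched with batch size 1 (so that y[0] in the filter is the label of the single element), and RES, N, NUM_CLASSES and class_mean are illustrative names that are not part of the original code.

```python
import time

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the real data: random tensors of shape
# (RES//8, RES//8, 4) with one-hot labels over NUM_CLASSES classes.
NUM_CLASSES = 10
RES = 32
N = 5000

rng = np.random.default_rng(0)
images = rng.random((N, RES // 8, RES // 8, 4), dtype=np.float32)
label_ids = rng.integers(0, NUM_CLASSES, size=N)
labels = np.eye(NUM_CLASSES, dtype=np.float32)[label_ids]

# Batch size 1, so y[0] inside filter() is the label of the single element.
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)


def class_mean(ds, c, skip_elem, n=100):
    """Mean of n elements of class c, after skipping skip_elem of them."""
    class_ds = (
        ds.filter(lambda x, y: tf.equal(tf.argmax(y[0]), c))
        .skip(skip_elem)
        .take(n)
        .unbatch()
    )
    total = class_ds.reduce(
        tf.zeros((RES // 8, RES // 8, 4), dtype=tf.float32),
        lambda acc, elem: acc + elem[0],  # elem is (x, y); sum only x
    )
    return total / n


# Time the same reduction with and without a skip offset.
for skip_elem in (0, 100):
    start = time.perf_counter()
    z = class_mean(train_ds, 3, skip_elem)
    print(f"skip={skip_elem}: {time.perf_counter() - start:.3f} s")
```

Running the same timing loop on the real train_ds instead of the synthetic one would show how much of the cost comes from the elements that skip() discards.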

I’ve tried batch() and rebatch(), but it seems that the reduction only runs in a reasonable time when I skip just a few elements. take() also increases the reduce time, but I expected that.

Why does this happen? Are there other ways to do this?
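One possible alternative, sketched on the same kind of hypothetical synthetic data: skip() does not jump ahead, it still pulls every skipped element through filter(), and the loop above repeats that work for every class. A single reduce() over the dataset can instead accumulate per-class sums and counts, and stop adding elements to a class once it has n_per_class of them. per_class_means and the data setup are illustrative assumptions, not an existing TensorFlow helper.

```python
import numpy as np
import tensorflow as tf

# Hypothetical synthetic data with one-hot labels, batched with batch size 1.
NUM_CLASSES = 10
RES = 32
N = 5000

rng = np.random.default_rng(0)
images = rng.random((N, RES // 8, RES // 8, 4), dtype=np.float32)
label_ids = rng.integers(0, NUM_CLASSES, size=N)
labels = np.eye(NUM_CLASSES, dtype=np.float32)[label_ids]
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)


def per_class_means(ds, n_per_class=100):
    """One pass over ds: mean of the first n_per_class elements of each class."""
    init = (
        tf.zeros((NUM_CLASSES, RES // 8, RES // 8, 4), dtype=tf.float32),
        tf.zeros((NUM_CLASSES,), dtype=tf.int32),
    )

    def step(state, elem):
        sums, counts = state
        x, y = elem  # x: (1, h, w, 4), y: (1, NUM_CLASSES)
        c = tf.argmax(y[0], output_type=tf.int32)
        # 1 if class c still needs samples, 0 once it has n_per_class of them.
        keep = tf.cast(counts[c] < n_per_class, tf.int32)
        idx = tf.reshape(c, [1, 1])
        # Adds x (or zeros) to row c of sums, and keep to counts[c].
        sums = tf.tensor_scatter_nd_add(sums, idx, x * tf.cast(keep, tf.float32))
        counts = tf.tensor_scatter_nd_add(counts, idx, tf.reshape(keep, [1]))
        return sums, counts

    sums, counts = ds.reduce(init, step)
    return sums / tf.cast(counts, tf.float32)[:, None, None, None]


means = per_class_means(train_ds)  # shape (NUM_CLASSES, RES//8, RES//8, 4)
```

This always reads the whole dataset once, so whether it beats the per-class loop depends on the dataset size and on skip_elem; its advantage is that the cost does not depend on which elements end up selected. For a random selection, a shuffled dataset (for example train_ds.shuffle(N)) could be passed instead of train_ds.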

Thanks!