Filtering plus interleaving a tf.data.Dataset takes hours

I’m filtering the dataset by certain labels. The call to the filtering method itself returns fine, but once I call next(iter(dataset)), for certain label values it keeps processing for more than 12 hours, while for other values it returns the result right away.

My filtering code is:

    def balanced_dataset(dataset, labels_list, sample_size=1000):
        datasets_list = []
        for label in labels_list:
            print(f'Preparing the dataset for label {label}')
            # Keep only the elements whose label matches `label`.
            # Binding label as a default argument avoids the late-binding
            # closure bug in a loop, and avoids the locals()[label] hack.
            filtered = dataset.filter(
                lambda x, y, label=label: tf.greater(
                    tf.reduce_sum(
                        tf.cast(
                            tf.equal(tf.constant(label, dtype=tf.int64), y),
                            tf.float32,
                        )
                    ),
                    tf.constant(0.),
                )
            )
            datasets_list.append(filtered.take(sample_size))

        # Build a dataset of datasets and flatten it by interleaving,
        # mixing the per-label samples into one dataset.
        ds = tf.data.Dataset.from_tensor_slices(datasets_list)
        concat_ds = ds.interleave(
            lambda x: x,
            cycle_length=len(labels_list),
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return concat_ds
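For what it's worth, a common alternative to the dataset-of-datasets interleave is `tf.data.Dataset.choose_from_datasets` (available as a static method since TF 2.7; `tf.data.experimental.choose_from_datasets` in older versions), which draws elements from the per-label datasets in a fixed round-robin order. Below is a minimal sketch with a toy 10-element dataset standing in for the real data; the names `features`, `labels`, and `per_label` are illustrative, not from the original post. Note the per-element cost is still dominated by `filter` having to scan past every rejected row.

```python
import tensorflow as tf

# Toy (feature, label) dataset standing in for the real data.
features = tf.range(10, dtype=tf.float32)
labels = tf.constant([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

labels_list = [0, 1]
sample_size = 3

# One filtered dataset per label, as in balanced_dataset above, with
# the label bound as a default argument to avoid late-binding issues.
per_label = [
    dataset.filter(lambda x, y, label=label: tf.equal(y, label)).take(sample_size)
    for label in labels_list
]

# choose_from_datasets pulls one element from datasets[i] for each
# index i emitted by the choice dataset: 0, 1, 0, 1, 0, 1 here,
# giving a strictly balanced round-robin over the labels.
choice = tf.data.Dataset.range(len(per_label)).repeat(sample_size)
balanced = tf.data.Dataset.choose_from_datasets(per_label, choice)

for x, y in balanced:
    print(int(x.numpy()), int(y.numpy()))
```

`tf.data.Dataset.sample_from_datasets` works similarly if you prefer stochastic rather than strict round-robin balancing.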


Sorry, you haven’t been clear enough.

I’m seeing similar behavior when I filter away the vast majority of the dataset. I have a column of 0s/1s, and I filter with dataset.filter(lambda input_dict: input_dict[pruning_feature_name] == 1). When only 1% of my dataset has input_dict[pruning_feature_name] == 1, my batch retrieval time goes from 0.0006 seconds to 20+ seconds. We also use interleave, albeit before the filter step. Did you manage to figure anything out with this @Marlon_Henrique_Teix ?

@MaxMarion In the end I decided to mine the relevant data, save it as CSV, and then use tf.data to load the data in a simpler manner. Now I’m using the TFX library, which has a more complete and better data pipeline for this. I suggest using it. Good luck!
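For anyone following the same route, a minimal sketch of the pre-filter-to-CSV approach (file path, column names, and values here are all made up for illustration): write the already-filtered rows once offline, then load them with `tf.data.experimental.CsvDataset`, so no per-element filtering happens at training time.

```python
import csv
import os
import tempfile

import tensorflow as tf

# Step 1 (offline): write the pre-filtered/balanced rows to CSV once.
path = os.path.join(tempfile.mkdtemp(), "balanced.csv")
with open(path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["feature", "label"])
    for feature, label in [(0.5, 0), (1.5, 1), (2.5, 0), (3.5, 1)]:
        writer.writerow([feature, label])

# Step 2 (training time): stream the CSV back with tf.data.
# record_defaults fixes the column dtypes; header=True skips row 1.
ds = tf.data.experimental.CsvDataset(
    path,
    record_defaults=[tf.float32, tf.int64],
    header=True,
)
for feature, label in ds:
    print(float(feature.numpy()), int(label.numpy()))
```

The same idea scales up with sharded CSVs plus `interleave` over filenames, or with TFRecord files, which is essentially what TFX's ExampleGen component manages for you.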
