tf.data Dataset with batches of n randomly chosen speakers and m of their utterances

I’m trying to set up a diarization system with a speaker embedder for further clustering.

I’ve decided to use the GE2E-XS loss, an extension of the GE2E loss used in Turn-to-Diarize, which is my main inspiration. This loss requires batches of shape [n speakers x m utterances]. Such random batching is a simple task when operating on a list of files, and I’ve found a PyTorch implementation of such a dataset and dataloader.

My problem is with the usage of tf.data.Dataset, which I build from TFRecords because otherwise my training is input-bound by a network HDD with indexing issues (it is slow on many small files). I’d like to get different speaker pairings with different utterances in each epoch.
Each of my tf.train.Example records has the structure {spectrogram, speaker_id, misc…}. Note that each speaker has a different number of samples.
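
For context, parsing such an Example looks roughly like this. It is a minimal sketch of the tfrecords_reader used in the code further below; the feature keys, the int64 speaker_id, and the assumption that spectrograms are stored as serialized float32 tensors are my own choices and must match how the records were actually written:

def tfrecords_reader(serialized):
    # hypothetical feature spec: keys and dtypes must match the writer
    features = tf.io.parse_single_example(
        serialized,
        {
            "spectrogram": tf.io.FixedLenFeature([], tf.string),
            "speaker_id": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    # spectrogram assumed to be written with tf.io.serialize_tensor
    spectrogram = tf.io.parse_tensor(features["spectrogram"], out_type=tf.float32)
    return spectrogram, {"speaker_id": features["speaker_id"]}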

My idea was to:

  1. Load the dataset from TFRecords and parse it;
  2. Shuffle it each epoch (to get a different order of utterances every epoch);
  3. Filter the dataset into one dataset per speaker, i.e. Dataset = {dataset_speaker_0, dataset_speaker_1, …, dataset_speaker_k}. The list of speakers is known and is loaded from a CSV with the sources I use to build the TFRecords;
  4. Shuffle the list of these speaker datasets (so I get a different order and combination of speakers in each epoch);
  5. Batch by selecting n speakers somehow and iterating through their datasets, taking m utterances from each, for as long as enough data is left (any unused data would come from speakers with fewer than m utterances left, or from the case where there are no n speakers left to select).

I’m not an expert in tf.data datasets by any means, and previously I worked mostly with cases where each sample was independent. I’m a bit worried about the complexity of such operations (how fast can filtering the dataset into individual datasets for each speaker be? Maybe I should directly prepare TFRecords for each speaker?). I also have no idea how to solve step #5. I’ve seen the choose_from_datasets method, but I’m still not sure how to use it in this scenario.
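
To make sure I understand choose_from_datasets at all, here is its behavior on a toy example (tf.data.Dataset.choose_from_datasets in TF >= 2.7, tf.data.experimental.choose_from_datasets in older versions). Each element of the choice dataset is an int64 index that pulls the next element from the corresponding input dataset:

datasets = [
    tf.data.Dataset.from_tensor_slices([10, 11, 12]),  # "speaker" 0
    tf.data.Dataset.from_tensor_slices([20, 21, 22]),  # "speaker" 1
]
# indices into `datasets`; each index takes the next element from that dataset
choice = tf.data.Dataset.from_tensor_slices(tf.constant([0, 0, 1, 1], dtype=tf.int64))
result = tf.data.Dataset.choose_from_datasets(datasets, choice)
print(list(result.as_numpy_iterator()))  # [10, 11, 20, 21]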

I’d appreciate your feedback on those two topics: the memory and time complexity of filtering one big dataset into per-speaker datasets, and the batching itself.

I came up with the following dataloader:

files = [str(wav_path) for wav_path in self._tfrecords_path.iterdir() if wav_path.suffix == ".tfrecord"]
dataset = tf.data.Dataset.from_tensor_slices(files)
dataset = tf.data.TFRecordDataset(dataset)
dataset = dataset.map(tfrecords_reader, num_parallel_calls=AUTOTUNE)

# Filtering for speakers: map() cannot return a Dataset, so build a plain
# Python list of per-speaker datasets (choose_from_datasets expects a list anyway)
speaker_datasets = [
    # the sid=speaker_id default argument binds the loop variable per lambda
    dataset.filter(lambda x, y, sid=speaker_id: y["speaker_id"] == sid)
    .cache()
    .shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
    .repeat()
    for speaker_id in self._speaker_list
]

datasets_idx = tf.data.Dataset.range(len(self._speaker_list))
        
# define this statically so it can be used inside the traced pair_speakers below
len_speakers = len(self._speaker_list)

def pair_speakers(x):
    """Pair all speakers in the dataset.
    Pairing is done in a way where each speaker is the anchor exactly once,
    and the pairs are re-generated on every iteration over the dataset.
    """

    # candidate partners: every speaker index except x (tf.where returns int64,
    # matching the int64 indices produced by Dataset.range)
    candidates = tf.reshape(tf.where(tf.range(len_speakers, dtype=tf.int64) != x), [-1])

    # sample num_speakers-1 partners without replacement; tf.random.shuffle is
    # used instead of tf.random.uniform_candidate_sampler, because the latter
    # does not exclude its true_classes from the sample and could draw x again
    pairs = tf.random.shuffle(candidates)[: self._num_speakers - 1]
    paired = tf.concat([[x], pairs], 0)  # concat speaker x with its partners

    # repeat each speaker index num_utters times, so choose_from_datasets pulls
    # num_utters consecutive examples per chosen speaker
    return tf.repeat(paired, tf.repeat(self._num_utters, self._num_speakers))

choice_dataset = datasets_idx.map(
    pair_speakers, num_parallel_calls=AUTOTUNE
)  # len_speakers elements, each an index vector of length num_speakers*num_utters

choice_dataset = choice_dataset.flat_map(
    tf.data.Dataset.from_tensor_slices
)  # a single flat dataset of len_speakers*num_speakers*num_utters indices

# deterministically takes samples from the filtered per-speaker datasets,
# in the order specified by choice_dataset
dataset = tf.data.Dataset.choose_from_datasets(speaker_datasets, choice_dataset).batch(
    self._num_speakers * self._num_utters
)

# apply transforms to the batched spectrograms; transforms are faster
# on whole batches than on single examples
dataset = dataset.map(
    lambda x, y: (self._transform(x), y),
    num_parallel_calls=AUTOTUNE,
)
dataset = dataset.prefetch(AUTOTUNE)

This code obviously belongs to a larger dataloader class, which I didn’t attach as it is not necessary for understanding.

In each iteration this dataloader loads different pairs of speakers, using different samples from them. Note that it works differently than stated in the original question (not all recordings are used in each iteration).
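
For completeness, this is how I consume the pipeline: the batch comes out flat as [num_speakers * num_utters, …], so I reshape the embeddings back to [num_speakers, num_utters, …] before the loss. A sketch, where embedder and ge2e_xs_loss are placeholders for my model and loss function:

for spectrograms, labels in dataset:
    # spectrograms: [num_speakers * num_utters, freq, time] (shapes illustrative)
    embeddings = embedder(spectrograms)  # -> [num_speakers * num_utters, embed_dim]
    # group utterances by speaker, as expected by GE2E-style losses
    embeddings = tf.reshape(embeddings, [self._num_speakers, self._num_utters, -1])
    loss = ge2e_xs_loss(embeddings)  # placeholder loss fn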

My question is still how to filter the dataset with many speakers in a way that results in a list of datasets, each containing samples from only one speaker. My current implementation has to iterate over the whole dataset as many times as there are speakers, which feels criminally inefficient, as this filtering could be done in a single pass over the data.
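
One single-pass alternative I want to profile is tf.data.experimental.group_by_window, which buckets examples by a key function as the data streams by, avoiding the k filtering passes entirely. A sketch, assuming speaker_id is an integer index; note it changes the sampling semantics, since groups arrive in the order their windows fill up rather than being drawn per anchor speaker:

def key_func(x, y):
    return tf.cast(y["speaker_id"], tf.int64)  # bucket key must be a scalar int64

def reduce_func(key, window):
    # emit one [num_utters, ...] group per filled window
    return window.batch(self._num_utters, drop_remainder=True)

grouped = dataset.apply(
    tf.data.experimental.group_by_window(
        key_func=key_func,
        reduce_func=reduce_func,
        window_size=self._num_utters,
    )
)
# stack num_speakers groups into [num_speakers, num_utters, ...] batches
batched = grouped.batch(self._num_speakers, drop_remainder=True)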

I guess I need to look into tf.data pipeline profiling myself, as performance stuff is data-related. All I can tell is that the first pass takes ~9min on 10GB TFRecord. After each filtered dataset is cached, every next iteration takes only 2 seconds. Easiest solution I could imagine is storing TFRecords as already filtered, but maybe there is a simple and efficient solution with current implementation. I’ve also noticed similar issues with Triplet Loss and some people using tf.data.Dataset.interleave() with block_length to smartly load only n_utters from each file.
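
This is how I understand the interleave() idea, assuming the TFRecords are rewritten as one file per speaker (the file layout is hypothetical):

# one TFRecord file per speaker; list_files shuffles the file list on every
# iteration, so the speaker order changes each epoch
files = tf.data.Dataset.list_files(str(self._tfrecords_path / "*.tfrecord"), shuffle=True)
dataset = files.interleave(
    lambda f: tf.data.TFRecordDataset(f).map(tfrecords_reader),
    cycle_length=self._num_speakers,  # read from num_speakers files at once
    block_length=self._num_utters,    # take num_utters examples before switching
    num_parallel_calls=AUTOTUNE,
)
dataset = dataset.batch(self._num_speakers * self._num_utters)

The caveat is that when a speaker’s file is exhausted mid-cycle, interleave fills the slot with the next file, so batches near the end of an epoch can lose the exact [n x m] structure.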

I think I might open another question in the general section asking for help with this filtering. I felt like updating this post with my implementation so that someone facing a similar issue could possibly get some help.