Understanding a TensorFlow data structure

I am looking at a Kaggle notebook that builds a tf.data.Dataset object and then splits the data stored in it for further processing. I am having a hard time understanding what happens to the data in the following two steps:

  1. Creating the dataset object from a tuple made of data and labels
  2. Splitting the dataset object again into a tuple of (a tuple and a list).

Please see below:

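import tensorflow as tf
from tensorflow import keras

def split_labels(x, y):
    # Regroup the first two feature components into a nested tuple; keep the label as-is.
    return (x[0], x[1]), y

t_dataset = tf.data.Dataset.from_tensor_slices((
    ...,  # features elided in the original post
    keras.utils.to_categorical(df_train['label'], num_classes=3),
))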

x_preprocessed = t_dataset.map(split_labels)
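A self-contained version of the snippet can make the question concrete. This is a sketch under assumptions: the features are elided in the post, so `features` below is a hypothetical (n, 2) array of sentence pairs, and `np.eye(3)[...]` stands in for `keras.utils.to_categorical(..., num_classes=3)` because it produces the same one-hot row layout. Printing `element_spec` shows the structure of a single element before and after `map`.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the elided features: an (n, 2) array of
# sentence pairs (assumption, the real input is not shown in the post).
features = np.array([["a man is sleeping", "a person rests"],
                     ["it is raining", "the sky is clear"]])

# One-hot labels with 3 columns, the same layout that
# keras.utils.to_categorical(labels, num_classes=3) produces.
labels = np.eye(3, dtype=np.float32)[[0, 2]]

def split_labels(x, y):
    # x is one row of `features` (shape (2,)); y is one one-hot row (shape (3,)).
    return (x[0], x[1]), y

t_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
x_preprocessed = t_dataset.map(split_labels)

# Element structure: (string tensor of shape (2,), float32 tensor of shape (3,))
print(t_dataset.element_spec)
# Element structure: ((string scalar, string scalar), float32 tensor of shape (3,))
print(x_preprocessed.element_spec)
```

Under these assumptions, each element changes from `(x, y)` to `((x[0], x[1]), y)`; the dataset itself is not a tuple.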

Do I understand correctly that the only difference between the data structure before the call to split_labels and after it is that:

before the call, the data structure is a tuple made up of two lists;
after the call, the data structure is a tuple made up of a tuple and a list?

Thank you

I don’t have hands-on experience with Keras, but below are my thoughts, which you might find helpful.

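split_labels is an ordinary named function, not a lambda, although it could be written as one; it is applied to each element of the original dataset. The higher-order function here is .map(), which does the applying. Because each element is an (x, y) tuple, map unpacks it and calls split_labels(x, y).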
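A small sketch of how `map` calls the function, using a toy integer dataset (an assumption, not the Kaggle data). It shows that `map` unpacks each `(x, y)` element into two arguments, and that a named function and an equivalent lambda give the same result.

```python
import tensorflow as tf

# Toy dataset whose elements are (features, label) tuples.
ds = tf.data.Dataset.from_tensor_slices(([[1, 2], [3, 4], [5, 6]], [0, 1, 2]))

# A named function: map unpacks each (x, y) element into two arguments.
def split_labels(x, y):
    return (x[0], x[1]), y

named = ds.map(split_labels)

# The same transformation written as a lambda.
anonymous = ds.map(lambda x, y: ((x[0], x[1]), y))

for (a, b), y in named.as_numpy_iterator():
    print(a, b, y)
# 1 2 0
# 3 4 1
# 5 6 2
```

The choice between a named function and a lambda is purely stylistic; `map` is what makes the per-element application happen.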

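from_tensor_slices slices its input along the first dimension, so each element of the dataset has one dimension fewer than the arrays you pass in. See the official documentation of tf.data.Dataset.from_tensor_slices for details. Strictly speaking, it is each element of the dataset, not the dataset itself, that is a tuple. The tuple passed to from_tensor_slices has a 2-dimensional list as its first component and, as its second, the one-hot label matrix from to_categorical, which is 2-dimensional with 3 columns (one per class), not 3-dimensional. Without seeing the full input I cannot predict exactly what each element will look like, so the best option is to run the code (printing t_dataset.element_spec shows the structure of one element) or to write a unit test that checks its behaviour under different conditions.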
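To see the slicing along the first dimension, here is a sketch with hypothetical data in the assumed layout: a (4, 2) string array of feature pairs and a (4, 3) one-hot label matrix (built with `np.eye` in place of `to_categorical`). The whole arrays have n = 4 rows; each dataset element holds one row from each.

```python
import numpy as np
import tensorflow as tf

# Hypothetical inputs: (n, 2) feature pairs and (n, 3) one-hot labels, n = 4.
features = np.array([["p0", "h0"], ["p1", "h1"], ["p2", "h2"], ["p3", "h3"]])
labels = np.eye(3, dtype=np.float32)[[0, 1, 2, 0]]

ds = tf.data.Dataset.from_tensor_slices((features, labels))

print(features.shape, labels.shape)   # (4, 2) (4, 3)
print(ds.cardinality().numpy())       # 4 elements, one per row
print(ds.element_spec)                # per-element shapes (2,) and (3,)

# The first element is a tuple of the first feature row and the first label row.
x, y = next(iter(ds))
print(x.numpy(), y.numpy())           # [b'p0' b'h0'] [1. 0. 0.]
```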

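After split_labels is applied to each element, every element gains one extra level of nesting: the feature component x becomes the tuple (x[0], x[1]), while the label y is passed through unchanged. The number of elements in the dataset stays the same, and this holds whatever x contains, as long as it can be indexed with [0] and [1].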
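A quick way to check both claims, sketched with the same hypothetical (n, 2) feature / (n, 3) one-hot layout as an assumption: the element count is unchanged by `map`, while the number of leaf tensors per element (counted with `tf.nest.flatten` on `element_spec`) grows from 2 to 3 because x is split into two components.

```python
import numpy as np
import tensorflow as tf

def split_labels(x, y):
    return (x[0], x[1]), y

# Hypothetical data in the assumed (n, 2) features / (n, 3) one-hot labels layout.
features = np.array([["p0", "h0"], ["p1", "h1"], ["p2", "h2"]])
labels = np.eye(3, dtype=np.float32)[[2, 0, 1]]

before = tf.data.Dataset.from_tensor_slices((features, labels))
after = before.map(split_labels)

# map transforms each element one-to-one, so the element count is unchanged.
print(before.cardinality().numpy(), after.cardinality().numpy())  # 3 3

# Leaf tensors per element: (x, y) has 2, ((x[0], x[1]), y) has 3.
print(len(tf.nest.flatten(before.element_spec)),
      len(tf.nest.flatten(after.element_spec)))  # 2 3

# The first element after map, unpacked according to its nested structure.
(premise, hypothesis), label = next(iter(after))
print(premise.numpy(), hypothesis.numpy(), label.numpy())  # b'p0' b'h0' [0. 0. 1.]
```

The names `premise` and `hypothesis` are only illustrative labels for the two feature components.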