tf.data.Dataset with tf.distribute

Hello, TensorFlow community. I'm trying to train a bidirectional LSTM model on a machine with 8 A100 GPUs, using tf.distribute.MirroredStrategy. My dataset is huge, millions of samples, and I initially load it as NumPy arrays. Following the official documentation, I tried to build the input pipeline with tf.data.Dataset.from_tensor_slices(), but I get an out-of-memory error. I don't know why, but apparently TF tries to place the entire dataset on the GPU (and apparently on a single GPU) before building the Dataset. As a workaround, I create the Dataset under a CPU device scope:

with tf.device('CPU'):
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
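For completeness, the rest of the setup is roughly the following; build_model, the per-replica batch size, and the loss are placeholders for my actual code:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()              # should pick up all 8 A100s
global_batch_size = 64 * strategy.num_replicas_in_sync   # 64 per replica is a placeholder

# x_train / y_train are the NumPy arrays described above.
# Building the Dataset on the CPU keeps the arrays from being copied to one GPU.
with tf.device('CPU'):
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.batch(global_batch_size)

with strategy.scope():
    model = build_model()                                # placeholder: my bidirectional LSTM
    model.compile(optimizer='adam', loss='binary_crossentropy')

model.fit(dataset, epochs=10)
```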
That worked. When building the Dataset I also set the batch size to the global batch size. After starting training, though, I noticed it was surprisingly slow, so I downsized the original dataset in order to test distributed training without creating a tf.data.Dataset.
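For reference, the batched pipeline looks roughly like this; shuffle and prefetch here are things I have been experimenting with rather than part of my original code:

```python
with tf.device('CPU'):
    dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(10_000)                                # placeholder buffer size
        .batch(global_batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)                     # overlap the input pipeline with GPU work
    )
```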
Indeed, if I pass the NumPy arrays directly to the model's fit method, training is 2-3 times faster than with a tf.data.Dataset. Why is this happening? What are the best practices for using tf.data.Dataset in distributed training? To train the model on all of my data I still have to convert it to a Dataset or a generator, otherwise I get the GPU out-of-memory error.
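By "generator" I mean something along these lines; the shapes, dtypes, and the sample_generator name are placeholders for my actual data:

```python
import tensorflow as tf

timesteps, features = 100, 32           # placeholder dimensions

def sample_generator():
    # Yield one (sequence, label) pair at a time from the NumPy arrays,
    # so the full dataset is never materialized as a single tensor.
    for x, y in zip(x_train, y_train):
        yield x, y

dataset = tf.data.Dataset.from_generator(
    sample_generator,
    output_signature=(
        tf.TensorSpec(shape=(timesteps, features), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),   # placeholder label spec
    ),
)
dataset = dataset.batch(global_batch_size).prefetch(tf.data.AUTOTUNE)
```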