GPU usage dips after each epoch to 0%

Hi everyone,
when training my model with model.fit(), using tf.data pipelines for the training and validation data, GPU usage dips to 0% after each epoch, even though I am calling prefetch() on the tf.data.Dataset.

Have you experienced something similar?
Sadly I cannot provide any code.

Thank you in advance.

My first two guesses would be:

  • The dataset has to refill the shuffle buffer at the start of every epoch. This happens with model.fit(ds.shuffle(buffer_size).repeat()); calling model.fit(ds.repeat().shuffle(buffer_size), steps_per_epoch=N) instead keeps the buffer filled across epoch boundaries.

  • Maybe something with the evaluation logic?
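A quick, self-contained way to see the difference between the two orderings on a toy dataset (the actual stall only becomes visible with large buffers, but the epoch-boundary behavior is the same):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)

# shuffle() before repeat(): each epoch is a complete, separately shuffled
# pass, so the buffer is drained and refilled at every epoch boundary.
a = ds.shuffle(10, seed=0).repeat(2)

# repeat() before shuffle(): one continuous stream, so the buffer never
# empties between epochs -- but elements from adjacent epochs can mix.
b = ds.repeat(2).shuffle(10, seed=0)

list_a = [int(x) for x in a]
list_b = [int(x) for x in b]

# With shuffle-then-repeat, the first 10 elements are exactly one epoch.
print(set(list_a[:10]) == set(range(10)))
```

The trade-off: repeat-then-shuffle avoids the refill stall but blurs epoch boundaries slightly, which is usually acceptable for training data.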

Thank you for your reply.
Currently I am using

model.fit(train_data, epochs=self.epochs, validation_data=val_data, verbose=1)

where train_data is a tf.data.Dataset built with

train_data = tf.data.Dataset.from_tensor_slices((train_ivs, train_logr, train_metric))
train_data = train_data.shuffle(buffer_size=train_ivs.shape[0], seed=self.seed, 
                                reshuffle_each_iteration=True)
train_data = train_data.batch(self.batch_size)
train_data = train_data.prefetch(tf.data.AUTOTUNE)

The evaluation step does not seem to cause any problems either.