Keras model stops training after first step of epoch

I am new to deep learning but am looking to employ the 3D-RCAN neural network to denoise microscopy images. Unfortunately, I’m having trouble training this network properly on my computer. My machine’s hardware and software environments are:

  • Windows 10
  • Intel(R) Xeon(R) Gold 5220R CPU, 128 GB memory
  • NVIDIA RTX A6000 GPU, 48 GB memory
  • CUDA v11.2 with cuDNN v8.1.1.33
  • NVIDIA Driver 511.09
  • Python 3.9.7 in an Anaconda environment
  • tensorflow 2.8.0 and tensorflow-gpu 2.8.0

The network was written and tested with tensorflow_gpu==1.13.1. I was able to train it on my older computer, which has a less powerful GPU, by reproducing the authors' environment. However, I cannot do this on my new computer with a more powerful GPU. After making some small modifications to the code, I can get the training script to run to completion without an error message, though I sometimes (but not always) receive the warning

~\anaconda3\lib\site-packages\keras\engine\ RuntimeWarning: Failed to sample a valid patch.

It is clear that the network is not functioning properly: it stops training after the first step of the first epoch whenever I set the steps_per_epoch hyperparameter to a value greater than 1. With steps_per_epoch=1, training completes.

My main question, then, is: what could cause the network to stop training after the first step when steps_per_epoch > 1? I have been looking into potential issues with the data generator but have not been able to determine whether it is the root cause. Batches are generated indefinitely with a Keras Sequence object.
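For reference, the generator follows the standard `keras.utils.Sequence` pattern, roughly like the sketch below. This is a simplified stand-in, not the actual 3D-RCAN code; the class name, the patch-sampling logic, and the `(noisy, clean)` pairing are all my own placeholders:

```python
import numpy as np
from tensorflow import keras

class PatchSequence(keras.utils.Sequence):
    """Simplified stand-in for the 3D-RCAN patch generator (hypothetical)."""

    def __init__(self, volumes, batch_size, patch_shape):
        self.volumes = volumes          # list of 3D numpy arrays
        self.batch_size = batch_size
        self.patch_shape = patch_shape

    def __len__(self):
        # Keras draws at most len(self) batches per epoch; if this is
        # smaller than steps_per_epoch, the iterator can be exhausted
        # before the epoch is counted as finished.
        return max(1, len(self.volumes) // self.batch_size)

    def __getitem__(self, idx):
        # Build one batch of patches (validation of the sampled patch,
        # which the real code performs, is omitted here).
        x = np.stack([self._sample_patch() for _ in range(self.batch_size)])
        return x, x  # a (noisy, clean) pair in the real generator

    def _sample_patch(self):
        v = self.volumes[np.random.randint(len(self.volumes))]
        # naive corner crop; the real code samples random valid locations
        return v[tuple(slice(0, s) for s in self.patch_shape)]
```

My suspicion is that the interaction between `__len__` (or the patch-validity check hinted at by the "Failed to sample a valid patch" warning) and steps_per_epoch is what exhausts the data early.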

A potential clue is that GPU memory usage spikes to 100% just as training exits. This is unexpected, since the same behavior does not occur on my old computer, which has far less GPU memory (6 GB).

The code appears to break down in `enumerate_epochs`, at line 1194 of the Keras engine source, because `self.insufficient_data` is set in `catch_stop_iteration`. What might be the underlying reason the iterator is running out of data?
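As far as I can tell from reading the source, training ends because the underlying data iterator raises StopIteration before steps_per_epoch batches have been drawn. Loosely, the control flow looks like the sketch below; this is my own simplification, not the actual Keras code, and the names only approximately mirror the real ones:

```python
def run_epochs(batches, epochs, steps_per_epoch):
    """Loose mimic of how Keras's DataHandler stops training when the
    data iterator is exhausted before steps_per_epoch batches are drawn."""
    it = iter(batches)           # finite iterator built from the dataset
    insufficient_data = False
    completed_steps = 0
    for epoch in range(epochs):
        if insufficient_data:    # enumerate_epochs stops yielding epochs
            break
        for step in range(steps_per_epoch):
            try:
                next(it)         # draw one batch
                completed_steps += 1
            except StopIteration:
                # analogous to catch_stop_iteration setting
                # self.insufficient_data and ending training
                insufficient_data = True
                break
    return completed_steps, insufficient_data
```

For example, with a single available batch and steps_per_epoch=2, this stops after one step, which matches the symptom I am seeing; with enough batches it runs all epochs normally.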