Out of memory issue with small model (500k parameters) and small to medium batch sizes

I am working with a comparatively small model with about 500k parameters. The model summary is:

__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 Input_left (InputLayer)     [(None, None)]               0         []                            
                                                                                                  
 Input_right (InputLayer)    [(None, None)]               0         []                            
                                                                                                  
 Encoder_left (Encoder)      (None, None, 64)             1088      ['Input_left[0][0]']          
                                                                                                  
 Encoder_right (Encoder)     (None, None, 64)             1088      ['Input_right[0][0]']         
                                                                                                  
 AE_enc_enc_right (AE_Encod  (None, None, 8)              520       ['Encoder_right[0][0]']       
 er)                                                                                              
                                                                                                  
 AE_enc_enc_left (AE_Encode  (None, None, 8)              520       ['Encoder_left[0][0]']        
 r)                                                                                               
                                                                                                  
 FRAE_enc_right (FRAE)       (2, None, 8)                 1760      ['AE_enc_enc_right[0][0]']    
                                                                                                  
 FRAE_enc_left (FRAE)        (2, None, 8)                 1760      ['AE_enc_enc_left[0][0]']     
                                                                                                  
 AE_enc_dec_right (AE_Decod  (2, None, 64)                576       ['FRAE_enc_right[0][0]']      
 er)                                                                                              
                                                                                                  
 AE_enc_dec_left (AE_Decode  (2, None, 64)                576       ['FRAE_enc_left[0][0]']       
 r)                                                                                               
                                                                                                  
 Attention_layer_2 (Multipl  (2, None, 64)                0         ['Encoder_left[0][0]',        
 y)                                                                  'AE_enc_dec_right[0][0]',    
                                                                     'Encoder_right[0][0]',       
                                                                     'AE_enc_dec_left[0][0]']     
                                                                                                  
 TCN_left (TCN)              (2, 7999, 512)               204480    ['Attention_layer_2[0][0]']   
                                                                                                  
 TCN_right (TCN)             (2, 7999, 512)               204480    ['Attention_layer_2[1][0]']   
                                                                                                  
 AE_tcn_enc_right (AE_Encod  (2, 7999, 8)                 4104      ['TCN_right[0][0]']           
 er)                                                                                              
                                                                                                  
 AE_tcn_enc_left (AE_Encode  (2, 7999, 8)                 4104      ['TCN_left[0][0]']            
 r)                                                                                               
                                                                                                  
 FRAE_tcn_right (FRAE)       (2, None, 8)                 5841      ['AE_tcn_enc_right[0][0]']    
                                                                                                  
 FRAE_tcn_left (FRAE)        (2, None, 8)                 5841      ['AE_tcn_enc_left[0][0]']     
                                                                                                  
 AE_tcn_dec_right (AE_Decod  (2, None, 512)               4608      ['FRAE_tcn_right[0][0]']      
 er)                                                                                              
                                                                                                  
 AE_tcn_dec_left (AE_Decode  (2, None, 512)               4608      ['FRAE_tcn_left[0][0]']       
 r)                                                                                               
                                                                                                  
 Attention_layer_1 (Multipl  (2, 7999, 512)               0         ['TCN_left[0][0]',            
 y)                                                                  'AE_tcn_dec_right[0][0]',    
                                                                     'AE_tcn_dec_left[0][0]',     
                                                                     'TCN_right[0][0]']           
                                                                                                  
 Masker_left (Masker)        (2, 7999, 64)                33344     ['Attention_layer_1[0][0]',   
                                                                     'Encoder_left[0][0]']        
                                                                                                  
 Masker_right (Masker)       (2, 7999, 64)                33344     ['Attention_layer_1[1][0]',   
                                                                     'Encoder_right[0][0]']       
                                                                                                  
 Decoder_left (Decoder)      (2, 64000)                   1024      ['Masker_left[0][0]']         
                                                                                                  
 Decoder_right (Decoder)     (2, 64000)                   1024      ['Masker_right[0][0]']        
                                                                                                 

Basically, this is a colleague's model that I am extending with additional submodels. As you can see, the model is rather small, with fewer than a million parameters. As I mentioned in other threads, training on the GPU is rather slow because of the FRAE model, which is recurrent, so I am trying to increase the batch size to speed up training.

However, I quickly run out of GPU memory when I try to train the entire model as listed with the fit() method. I have 24 GB of GPU memory (our cluster even goes up to 48 GB), yet I cannot train with a batch size above 16-32. According to my calculations of the GPU memory required for this model size and data size, that should be more than enough to train with batch sizes > 128. Yet when I set the batch size accordingly, I immediately run out of memory.

I tried to set allow_memory_growth to true via

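import os
import tensorflow as tf

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# list_physical_devices expects a device type ("GPU"), not a device name ("/GPU:0")
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)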

but this did not change anything either. The data I am using consists of audio files with 64000 samples each, which should not be excessive.

I am wondering why my model consumes so much memory. Does anyone have an idea what might cause this? The data I am using is stored in TFRecords, and I suspect that this might somehow be causing the out-of-memory issue. Or is it the multiplication? That is where the out-of-memory error is usually raised. I am at my wits’ end.
I cannot believe that this model has to require so much memory, seeing that you can train much larger models on similar data with about the same amount of GPU RAM.

The high GPU memory consumption during training is primarily due to the many intermediate tensors created by layers such as the TCN and the attention mechanisms, and it grows with the batch size. In addition, the recurrent nature of the FRAE model and a potentially inefficient data pipeline built on TFRecords could be adding to the memory usage.

Cheers,

Tim
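
To put rough numbers on the intermediate tensors mentioned above, here is a back-of-the-envelope sketch. It assumes float32 activations and takes the per-sample shape (7999, 512) from the TCN rows of the model summary. The number of such tensors kept for backpropagation (`n_saved`) is purely hypothetical, since the internals of the TCN are not shown in this thread.

```python
# Back-of-the-envelope activation estimate; float32 (4 bytes) is an assumption.
BYTES_PER_FLOAT32 = 4


def tensor_bytes(shape, batch_size, bytes_per_elem=BYTES_PER_FLOAT32):
    """Bytes needed for one tensor with per-sample `shape` at `batch_size`."""
    n_elements = batch_size
    for dim in shape:
        n_elements *= dim
    return n_elements * bytes_per_elem


# Weights: roughly 500k parameters, as stated in the post.
weights = 500_000 * BYTES_PER_FLOAT32

# One TCN output tensor, shape (7999, 512) per sample, from the model summary.
tcn_out_b1 = tensor_bytes((7999, 512), batch_size=1)
tcn_out_b128 = tensor_bytes((7999, 512), batch_size=128)

# Hypothetical count of same-sized tensors kept for the backward pass.
n_saved = 10

print(f"weights:                     {weights / 2**20:.1f} MiB")
print(f"one TCN tensor, batch 1:     {tcn_out_b1 / 2**20:.1f} MiB")
print(f"one TCN tensor, batch 128:   {tcn_out_b128 / 2**30:.2f} GiB")
print(f"{n_saved} such tensors, batch 128: {n_saved * tcn_out_b128 / 2**30:.1f} GiB")
```

Under these assumptions the weights take about 2 MB, while a single (7999, 512) tensor at batch size 128 is already close to 2 GiB. A handful of such tensors held for the backward pass would be enough to fill a 24 GB card, which would explain why the parameter count says little about the memory needed.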

Thanks. Is there anything I could do to change that aside from reducing the batch size (or the model, but I have to work with the current setup)?

In what way could the TFRecords be inefficient? I am wary of them, but have not found clear statements about their memory requirements.
My fetch method in the Datagenerator looks like this:

def fetch(self):
    # Read the serialized examples and decode them sequentially
    # (num_parallel_calls was tf.data.experimental.AUTOTUNE)
    dataset = tf.data.TFRecordDataset(self.tfr).map(self._decode,
                                                    num_parallel_calls=None)
    if self.mode == "train":
        # Buffer of 2000 decoded examples; same shuffle order in every epoch
        dataset = dataset.shuffle(2000, reshuffle_each_iteration=False)
        train_dataset = dataset.batch(self.batch_size, drop_remainder=True)
        train_dataset = train_dataset.prefetch(1)  # was tf.data.experimental.AUTOTUNE
        return train_dataset

I tried reducing the buffer size in the shuffle method, but that did not change anything.
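
As a rough check on the pipeline itself: as far as I know, tf.data transformations run on the CPU by default, so the shuffle buffer lives in host RAM rather than in GPU memory. The sketch below estimates the sizes involved. It assumes each decoded example holds two float32 channels (left and right) of 64000 samples; the dtype and the layout are assumptions, since `_decode` is not shown, and any target signals stored in the records would add to these numbers.

```python
# Host-side size estimate for the tf.data pipeline above.
# Assumptions: float32 samples and two channels (left/right) per example.
SAMPLES_PER_FILE = 64_000   # from the post
CHANNELS = 2                # assumption: one left and one right signal
BYTES_PER_SAMPLE = 4        # assumption: float32
SHUFFLE_BUFFER = 2000       # buffer size passed to dataset.shuffle()


def example_bytes():
    """Bytes for one decoded example under the assumptions above."""
    return SAMPLES_PER_FILE * CHANNELS * BYTES_PER_SAMPLE


def batch_bytes(batch_size):
    """Bytes for one input batch as it would be copied to the GPU."""
    return batch_size * example_bytes()


shuffle_buffer_bytes = SHUFFLE_BUFFER * example_bytes()

print(f"one example:     {example_bytes() / 2**10:.0f} KiB")
print(f"shuffle buffer:  {shuffle_buffer_bytes / 2**30:.2f} GiB (host RAM)")
print(f"batch of 128:    {batch_bytes(128) / 2**20:.1f} MiB")
```

If these assumptions hold, the shuffle buffer costs about 1 GB of host memory and an input batch of 128 is in the tens of megabytes, which would be consistent with the observation that shrinking the shuffle buffer did not change the GPU out-of-memory behaviour.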