`tf.signal.irfft` causes continuous GPU memory growth

When I use `tf.signal.irfft` in TensorFlow 2.9, GPU memory usage grows continuously; the GPU is an NVIDIA GeForce RTX 4090. The same code does not exhibit this issue on TensorFlow 2.4 with an RTX 2080.
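A minimal loop along these lines (the spectrum shape is an assumption, and `tf.config.experimental.get_memory_info` is available from TF 2.5 onward) can be used to watch GPU memory while `tf.signal.irfft` runs repeatedly:

```python
import tensorflow as tf

# dummy complex spectrum; the shape is an assumption, not the real model's
spec = tf.complex(tf.random.normal([8, 249, 257]),
                  tf.random.normal([8, 249, 257]))

for step in range(1000):
    _ = tf.signal.irfft(spec)
    if step % 100 == 0:
        info = tf.config.experimental.get_memory_info("GPU:0")
        print(step, info["current"], info["peak"])  # bytes in use / peak bytes
```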

Hi @lucky_hu

Welcome to the TensorFlow Forum!

Could you please retry with the latest TensorFlow version (2.13 or 2.14) and let us know if the issue still persists? Thank you.

I cannot use TensorFlow 2.14. When I use `tf.keras.layers.LayerNormalization(axis=[-1, -2])`, the program fails with the error `cuDNN launch failure: input shape ([1, 1992, 257, 1])`. The original input shape is [8, 249, 257, 1].
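For context, 8 × 249 = 1992, which matches the [1, 1992, 257, 1] shape in the error, so the batch and time axes appear to be flattened together before the kernel launch. A minimal repro along these lines (my reconstruction, not the full model) should be enough to test it:

```python
import tensorflow as tf

x = tf.random.normal([8, 249, 257, 1])
ln = tf.keras.layers.LayerNormalization(axis=[-1, -2])
y = ln(x)  # reportedly raises "cuDNN launch failure" on the affected setup
print(y.shape)
```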

I tried 2.14 and it still didn't solve my problem.

Could you please share reproducible code so we can replicate the error and understand the issue? Thank you.

```python
import tensorflow as tf
from tensorflow import keras


class stft(keras.layers.Layer):

    def __init__(self, block_len=512, mode="mag_pha"):
        super().__init__()
        self.block_len = block_len
        self.win = tf.signal.hann_window(block_len)
        self.mode = mode

    def call(self, x):
        # 50%-overlapping frames, windowed before the FFT
        frames = tf.signal.frame(x, self.block_len, self.block_len // 2)
        frames = self.win * frames
        stft_dat = tf.signal.rfft(frames)
        if self.mode == "mag_pha":
            mag = tf.math.abs(stft_dat)
            phase = tf.math.angle(stft_dat)
            output_list = [mag, phase]
        else:
            real = tf.math.real(stft_dat)
            imag = tf.math.imag(stft_dat)
            output_list = [real, imag]
        return output_list
```
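A quick shape check for the layer above (the one-second, 16 kHz input is an assumption):

```python
wav = tf.random.normal([8, 16000])                     # dummy batch of waveforms
mag, phase = stft(block_len=512, mode="mag_pha")(wav)
print(mag.shape, phase.shape)                          # (8, 61, 257) each
```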

```python
class ifftLayer(keras.layers.Layer):

    def __init__(self, block_len=512, mode="mag_pha"):
        super().__init__()
        self.block_len = block_len
        self.win = tf.signal.hann_window(block_len)
        self.mode = mode

    def call(self, x):
        if self.mode == "mag_pha":
            # rebuild the complex spectrum from magnitude and phase
            s1_stft = tf.cast(x[0], tf.complex64) * tf.exp(1j * tf.cast(x[1], tf.complex64))
        else:
            s1_stft = tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64)

        # inverse FFT back to time-domain frames, re-windowed for synthesis
        irfft = tf.signal.irfft(s1_stft) * self.win
        return irfft
```
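Pairing it with the stft layer gives a quick round trip (continuing the dummy tensors from the sketch above):

```python
frames = ifftLayer(block_len=512, mode="mag_pha")([mag, phase])
print(frames.shape)  # (8, 61, 512): windowed time-domain frames
```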

```python
# reshape, DRC_Block_T, mk_mag_mask and overlapAddLayer are defined
# elsewhere in the project and are not shown here.
def Spon_DRC_NET_Block_Concat_T(batch_size=256, block_len=512, gain=None,
                                num=16, epoch=2, input_norm='iLN'):
    # input layer for the time signal
    time_data = tf.keras.layers.Input(batch_shape=(batch_size, None))
    mag, _ = stft()(time_data)
    # mag, _ = tf.keras.layers.Lambda(stft, arguments={'mode': 'mag_pha'})(time_data)
    mag = tf.keras.layers.Lambda(reshape, arguments={'axis': [batch_size, -1, block_len // 2 + 1, 1]})(mag)

    # encoder
    if input_norm == "iLN":
        input_complex_spec = tf.keras.layers.LayerNormalization(axis=[-1, -2])(mag)
    elif input_norm == "BN":
        input_complex_spec = tf.keras.layers.BatchNormalization()(mag)

    # causal padding [1, 0] [0, 2]
    # encode
    conv_1 = tf.keras.layers.Conv2D(num, (2, 5), (1, 2), padding="valid", name="CONV")(input_complex_spec)
    bn_1 = tf.keras.layers.LayerNormalization(axis=[-1, -2])(conv_1)
    out_1 = tf.keras.layers.PReLU(shared_axes=[1, 2])(bn_1)

    out_1 = DRC_Block_T(numUnits=num, batch_size=batch_size, L=-1, width=127,
                        channel=num, epoch=epoch, causal=True)(out_1)

    c_out_1 = tf.keras.layers.Conv2D(num, (1, 1), (1, 1), padding='same')(out_1)
    skipcon_5 = tf.keras.layers.concatenate([c_out_1, out_1])
    deconv_5 = tf.keras.layers.Conv2DTranspose(1, (2, 5), (1, 2), padding="valid",
                                               use_bias=False, name="DCONV4")(skipcon_5)

    output_mask = tf.keras.activations.sigmoid(deconv_5)

    enh_spec = tf.keras.layers.Lambda(mk_mag_mask, arguments={'gain': gain})([time_data, output_mask])

    enh_frame = ifftLayer()(enh_spec)

    enh_time = tf.keras.layers.Lambda(overlapAddLayer, name='enhanced_time')(enh_frame)
    model = tf.keras.models.Model(time_data, enh_time)
    model.summary()
    return model
```
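`reshape`, `DRC_Block_T`, `mk_mag_mask`, and `overlapAddLayer` are not shown in the thread. A minimal stand-in for the overlap-add step (my assumption, not the poster's code) could be built on the built-in `tf.signal.overlap_and_add`:

```python
import tensorflow as tf

def overlap_add_sketch(frames, block_len=512):
    """Hypothetical replacement for overlapAddLayer.

    frames: (batch, num_frames, block_len); a hop of block_len // 2
    matches the 50% overlap used by the stft layer above.
    """
    return tf.signal.overlap_and_add(frames, frame_step=block_len // 2)
```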