RESOURCE_EXHAUSTED when running TimeDistributed on MultiHeadAttention

Hi, I have built a model for time-series classification using a transformer, and I am getting an out-of-memory error even after reducing all of the hyperparameters.
Here is my model:
```python
import tensorflow as tf


class Transformer_layer(tf.keras.layers.Layer):
    def __init__(self, i_shape, num_transformer_blocks, head_size, num_heads, ff_dim, dropout, mlp_units, mlp_dropout):
        super().__init__()
        self.i_shape = i_shape
        self.num_transformer_blocks = num_transformer_blocks
        self.head_size = head_size
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout = dropout
        self.mlp_units = mlp_units
        self.dense_layers = []
        for dim in self.mlp_units:
            self.dense_layers.append(tf.keras.layers.Dense(dim, activation="relu"))

        self.mlp_dropout = mlp_dropout
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.attention = tf.keras.layers.MultiHeadAttention(
            key_dim=head_size, num_heads=num_heads, dropout=dropout)
        self.dropout_layer = tf.keras.layers.Dropout(dropout)
        self.conv1 = tf.keras.layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")
        self.conv2 = tf.keras.layers.Conv1D(filters=i_shape[-1], kernel_size=1)
        self.mlp_dropout_layer = tf.keras.layers.Dropout(self.mlp_dropout)

    def call(self, x):
        # (1024, 1)
        for _ in range(self.num_transformer_blocks):
            x = self.transformer_encoder(x)

        x = tf.keras.layers.GlobalAveragePooling1D(data_format="channels_first")(x)
        for dense_layer in self.dense_layers:
            x = dense_layer(x)
            x = self.mlp_dropout_layer(x)
        # (128)
        return x

    def transformer_encoder(self, inputs):
        # Normalization and attention
        x = self.norm(inputs)
        x = self.attention(x, x)
        x = self.dropout_layer(x)
        res = x + inputs

        # Feed-forward part
        x = self.norm(res)
        x = self.conv1(x)
        x = self.dropout_layer(x)
        x = self.conv2(x)

        return x + res

    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape)  # (None, 1024, 1)
        input_shape = input_shape.with_rank_at_least(2)
        if tf.compat.dimension_value(input_shape[-1]) is None:
            raise ValueError(
                'The innermost dimension of `input_shape` must be defined, '
                'but saw: {}'.format(input_shape))
        return input_shape[:-2].concatenate(self.mlp_units[-1])  # (None, 128)


def build_transformer_model(shape, head_size, num_heads, ff_dim, num_transformer_blocks,
                            mlp_units, mlp_dropout=0, dropout=0, n_classes=4):
    print("Building Model", end="")
    inputs = tf.keras.layers.Input(shape=shape)
    x = inputs  # (batch_size, 1200, 1024, 1) → (patients, 1200 timesteps, 30 s windows, 1 feature)

    transformer_layer = Transformer_layer(shape, num_transformer_blocks, head_size, num_heads, ff_dim, dropout,
                                          mlp_units, mlp_dropout)
    x = tf.keras.layers.TimeDistributed(transformer_layer, name="Transformer")(x)
    # x = transformer_layer(x)
    # (batch_size, 1200, 128)

    x = tf.keras.layers.LSTM(2, return_sequences=True)(x)  # (batch_size, 1200, 2)
    output = tf.keras.layers.Dense(n_classes, activation="softmax")(x)  # (batch_size, 1200, 4)

    model = tf.keras.Model(inputs=inputs, outputs=output)
    print("Model Compiled")
    return model
```

I am running the model with batch_size=1, head_size=8, num_heads=2, ff_dim=2, num_transformer_blocks=4, mlp_units=[8], mlp_dropout=0.1, dropout=0.25.

My GPU is:
```
NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100 80G...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   28C    P0    42W / 300W |      4MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
```

My data has shape (batch_size, 1200, 750, 1),

and I am getting this error:
```
Node: 'model/Transformer/while/transformer_layer/multi_head_attention/softmax/Softmax_2'
2 root error(s) found.
(0) RESOURCE_EXHAUSTED: OOM when allocating tensor with shape[1,2,750,750] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[{{node model/Transformer/while/transformer_layer/multi_head_attention/softmax/Softmax_2}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

 [[StatefulPartitionedCall/confusion_matrix/assert_less/Assert/AssertGuard/pivot_f/_2736/_59]]

Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

(1) RESOURCE_EXHAUSTED: OOM when allocating tensor with shape[1,2,750,750] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[{{node model/Transformer/while/transformer_layer/multi_head_attention/softmax/Softmax_2}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_18151]
None
```

The RESOURCE_EXHAUSTED (Out Of Memory, OOM) error means the model needs more GPU memory than is available. Given your hyperparameters and an 80 GB NVIDIA A100, it may look surprising to hit this with a batch size of 1, but the error message contains the key hint: the failing allocation is the attention score tensor of shape [1, 2, 750, 750], i.e. (batch, heads, query length, key length). On its own that tensor is only a few megabytes, but TimeDistributed unrolls the transformer in a while loop over all 1200 timesteps, and every intermediate activation of every one of the 4 transformer blocks has to be kept for the backward pass, so the footprint scales with the sequence length times the model depth.
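
A rough back-of-the-envelope estimate makes this concrete (a sketch assuming float32 activations and counting only the attention softmax outputs kept for backpropagation; the true footprint is larger):

```python
# Hypothetical memory estimate for the configuration in the question.
batch, heads, window, timesteps, blocks = 1, 2, 750, 1200, 4
bytes_per_float32 = 4

one_softmax = batch * heads * window * window * bytes_per_float32  # one [1, 2, 750, 750] tensor, ~4.3 MiB
kept_for_backprop = one_softmax * timesteps * blocks               # one per timestep, per block

print(f"{kept_for_backprop / 2**30:.1f} GiB")  # ~20 GiB for the softmax outputs alone
```

That is before counting the pre-softmax scores, dropout masks, Q/K/V projections, feed-forward activations, and optimizer state, so exhausting even 80 GB is plausible.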

Here are some strategies to mitigate this issue:

  1. Reduce Sequence Length or Model Depth: Attention memory grows quadratically with the 750-sample window, and TimeDistributed multiplies everything by the 1200 outer timesteps, so both dimensions matter. Consider downsampling the per-timestep windows, shortening the outer sequence if possible, or reducing the number of transformer blocks (a sketch of window downsampling follows after this list).
  2. Gradient Checkpointing: Store only a subset of intermediate activations during the forward pass and recompute the rest during the backward pass, reducing memory at the cost of extra computation. TensorFlow's tf.recompute_grad can be used for this, but you need to wrap the memory-intensive operations yourself; a sketch follows after this list.
  3. Model Simplification: Simplify the model by reducing the number of heads in the MultiHeadAttention layer, the dimensions of the feedforward network (ff_dim), or the number of dense layers. Each of these will reduce the memory footprint.
  4. Mixed Precision Training: Use mixed precision to reduce the memory footprint; the NVIDIA A100 supports it very efficiently. TensorFlow enables it globally with tf.keras.mixed_precision.set_global_policy('mixed_float16'); see the sketch after this list.
  5. Use tf.data and Optimize Input Pipeline: Ensure that your input pipeline is not the bottleneck and is efficiently batching the data. Utilize tf.data API to prefetch and batch your data efficiently. This won’t reduce the memory usage per se but will ensure that data loading is not contributing to the memory issue.
  6. Model Parallelism: If feasible, consider splitting the model across multiple GPUs. This is a more complex solution and might require significant changes to your model architecture and training script.
  7. Check for Memory Leaks: Ensure that there are no memory leaks in your code. This can happen if tensors are inadvertently kept alive, leading to gradual memory buildup.
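
For strategy 1, one way to cut the quadratic attention cost without touching the outer 1200-step sequence is to downsample each 750-sample window before the attention blocks, for example with an average-pooling layer. This is a hypothetical addition to the Transformer_layer class from the question; whether pooling loses too much signal depends on your data.

```python
# In Transformer_layer.__init__:
self.downsample = tf.keras.layers.AveragePooling1D(pool_size=5)  # 750 samples -> 150

# At the top of Transformer_layer.call, before the encoder blocks:
x = self.downsample(x)
# Attention scores become [batch, heads, 150, 150] instead of [batch, heads, 750, 750],
# i.e. 25x less memory per score matrix.
```

For strategy 2, here is a minimal sketch of gradient checkpointing that wraps only the encoder blocks of the Transformer_layer class. Treat it as a starting point rather than a drop-in fix: how tf.recompute_grad interacts with Keras layers, dropout, and variable tracking can differ between TensorFlow versions, so verify the gradients on a small run first.

```python
class Transformer_layer(tf.keras.layers.Layer):
    # __init__, transformer_encoder and compute_output_shape stay as in the question.

    def call(self, x):
        # Recompute each encoder block's intermediates during the backward pass
        # instead of keeping them in GPU memory (trades compute for memory).
        checkpointed_encoder = tf.recompute_grad(self.transformer_encoder)
        for _ in range(self.num_transformer_blocks):
            x = checkpointed_encoder(x)

        x = tf.keras.layers.GlobalAveragePooling1D(data_format="channels_first")(x)
        for dense_layer in self.dense_layers:
            x = dense_layer(x)
            x = self.mlp_dropout_layer(x)
        return x
```

For strategy 4, mixed precision is a one-line change made before the model is built; the only adjustment usually needed inside the model is forcing the final softmax layer to float32 for numerical stability. A sketch reusing build_transformer_model and the hyperparameters from the question:

```python
import tensorflow as tf

# Compute in float16 while keeping variables in float32; this roughly halves
# activation memory and uses the A100's tensor cores.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = build_transformer_model(shape=(1200, 750, 1), head_size=8, num_heads=2,
                                ff_dim=2, num_transformer_blocks=4,
                                mlp_units=[8], mlp_dropout=0.1, dropout=0.25)

# Inside build_transformer_model, keep the classifier head numerically stable:
# output = tf.keras.layers.Dense(n_classes, activation="softmax", dtype="float32")(x)
```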

Implementing these strategies requires a balance between model complexity and resource constraints. Start with the simpler approaches like reducing model complexity and enabling mixed precision, and progressively move to more complex solutions if necessary.