MultiHeadAttention With 2 Attention Axes And An Attention Mask - How To Apply The Mask?

I have a small GPT model for text generation. I modified it to work on vectors rather than softmax outputs, so it basically generates a vector for the predicted next token. Normally, the output shape of the model would be (None, seq_len, vocab_size); now it's (None, seq_len, vector_size). There's no problem with that, and I can run the model. What I'd like to add is another dimension to the vectors: I want each token to be represented by two vectors, so the output shape will be (None, seq_len, num_vectors, vector_size), and I want to attend over the last two dimensions (num_vectors, vector_size). I can run the model with this structure. However, I'm having problems with how to apply an attention mask to such an architecture; some shape problems arise within the Transformer model. Can you help me build a mask function?
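
For context, here is a minimal shape check (the sizes are dummies, not my real hyperparameters) showing the 4D layout I feed into MultiHeadAttention with attention_axes=[1, 2]:

import tensorflow as tf
from tensorflow.keras import layers

# Dummy sizes for illustration only: batch=2, seq_len=5, num_vectors=2, vector_size=16
x = tf.random.normal((2, 5, 2, 16))
mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, attention_axes=[1, 2])
y = mha(x, x)       # attention is computed jointly over axes 1 and 2
print(y.shape)      # (2, 5, 2, 16) -- same 4D shape as the input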

Transformer model:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def causal_attention_mask(batch_size, n_dest, n_src, num_vectors, dtype):
    """
    Mask the upper half of the dot product matrix in self attention.
    This prevents flow of information from future tokens to current token.
    1's in the lower triangle, counting from the lower right corner.
    """
    i = tf.range(n_dest)[:, None]
    j = tf.range(n_src)
    m = i >= j - n_src + n_dest
    mask = tf.cast(m, dtype)
    tf.print(mask)
    mask = tf.reshape(mask, [1, n_dest, n_src])
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.convert_to_tensor([1, 1])], 0
    )
    return tf.tile(mask, mult)
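
For reference, this is what the function currently produces: a plain 3D mask of shape (batch_size, n_dest, n_src). The num_vectors argument isn't used anywhere yet; that's exactly the part I'm unsure how to handle:

# Tiny example: batch_size=1, n_dest=n_src=4
print(causal_attention_mask(tf.constant(1), 4, 4, 2, "int32"))
# shape (1, 4, 4), lower triangular:
# [[[1 0 0 0]
#   [1 1 0 0]
#   [1 1 1 0]
#   [1 1 1 1]]]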

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, attention_axes=[1, 2])
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[2]  # NOTE: for inputs of shape (batch, seq_len, num_vectors, embed_dim), index 2 is num_vectors, not seq_len
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, 2, "bool")
        # causal_mask is built but not passed below; passing attention_mask=causal_mask
        # here is what triggers the shape errors described above.
        attention_output = self.att(inputs, inputs)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

    def get_weights(self):
        return [self.att.get_weights(), self.ffn.layers[0].get_weights(), self.ffn.layers[1].get_weights()]

    def set_weights(self, weights):
        self.att.set_weights(weights[0])
        self.ffn.layers[0].set_weights(weights[1])
        self.ffn.layers[1].set_weights(weights[2])
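
As a sanity check, the block runs fine on a 4D input as long as I leave the mask out (again with dummy sizes):

block = TransformerBlock(embed_dim=16, num_heads=2, ff_dim=32)
out = block(tf.random.normal((2, 5, 2, 16)))  # (batch, seq_len, num_vectors, embed_dim)
print(out.shape)                              # (2, 5, 2, 16)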

Main model:

def create_model():
    inputs = layers.Input(shape=(maxlen,), dtype="float32")
   
    node_embedding_layer = NodeEmbedding(embed_dim, final_embeddings, name='node_embedding_layer')
    x2 = node_embedding_layer(inputs)
    # x2 = tf.transpose(x2, perm=[0, 2, 1, 3])  # (batch_size, vector_count, seq_len, vector_dim)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim, name='node_transformer_layer')
    x2 = transformer_block(x2)

    # Calculate final output node embeddings
    outputs_2 = layers.Dense(node_embedding_dim // 2, name='graph_embedding_output')(x2)
    
    # Concatenate outputs of both branches
    # outputs = layers.Average()([outputs_1, outputs_2])

    outputs = outputs_2

    # Softmax layer
    # softmax_output = layers.Softmax()(outputs)
    
    model = tf.keras.Model(inputs=inputs, outputs=[outputs])
    # loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        optimizer=keras.optimizers.Adam(0.001),
        loss=[euclidean_distance_loss],
        metrics=['mse']
    )  # No loss and optimization based on word embeddings from transformer block
    return model
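
For completeness, the remaining names are defined roughly like this. The values, the NodeEmbedding body, and the euclidean_distance_loss body below are only placeholders so that create_model() is runnable as posted, not my exact implementations:

import numpy as np

vocab_size = 100
maxlen = 10
embed_dim = 16
num_heads = 2
feed_forward_dim = 32
node_embedding_dim = 32
final_embeddings = np.random.rand(vocab_size, 2, embed_dim).astype("float32")  # two vectors per token

class NodeEmbedding(layers.Layer):
    # Placeholder: looks up the two precomputed vectors for each token id,
    # giving (batch, seq_len, num_vectors=2, embed_dim).
    # embed_dim is kept only to match the call signature used in create_model().
    def __init__(self, embed_dim, embeddings, **kwargs):
        super().__init__(**kwargs)
        self.table = tf.constant(embeddings, dtype="float32")

    def call(self, inputs):
        return tf.gather(self.table, tf.cast(inputs, "int32"))

def euclidean_distance_loss(y_true, y_pred):
    # Placeholder: mean L2 distance between target and predicted vectors
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1)))

model = create_model()
model.summary()  # final output shape: (None, 10, 2, 16)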