Implementing MultiHeadAttention() in a simple model


I am trying to implement the MultiHeadAttention layer in a small model.

I would like to build the equivalent of self-attention
with this layer, in a model similar to this one:

import tensorflow as tf
inp = tf.keras.layers.Input((10, 64))
layer = tf.keras.layers.AdditiveAttention()([inp, inp])
model = tf.keras.Model(inputs=inp, outputs=layer)

but with AdditiveAttention replaced by MultiHeadAttention, like this:

inp = tf.keras.layers.Input((10, 64))
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)([inp, inp, inp])
model = tf.keras.Model(inputs=inp, outputs=layer)

However, doing this gives the error “call() missing 1 required positional argument: ‘value’”.
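For reference, a minimal sketch of a call that should avoid this error, based on the documented signature: unlike AdditiveAttention, which takes a list [query, value], MultiHeadAttention takes query and value as separate positional arguments (key defaults to value when omitted). The whole list [inp, inp, inp] is therefore bound to query, and value is missing. The name mha is just illustrative.

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(10, 64))
# Pass query and value separately instead of as a list.
# Using inp for both gives self-attention; key defaults to value when omitted.
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)
out = mha(inp, inp)
model = tf.keras.Model(inputs=inp, outputs=out)

# Run a random batch through the model to check the output shape.
y = model(tf.random.normal((3, 10, 64)))
print(y.shape)  # (3, 10, 64)
```

Passing the key explicitly, as in mha(inp, inp, inp), should be equivalent here.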

I have done some research on how the MultiHeadAttention layer works, but I am still not sure what the key_dim and value_dim parameters do.
I would have expected key_dim to change the output shape, but when I change its value in the example from the “tf.keras.layers.MultiHeadAttention | TensorFlow Core v2.9.1” documentation page, the output shape stays the same:
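For reference, the Keras documentation describes key_dim as the size of each attention head for query and key, and value_dim as the size of each head for value (it defaults to key_dim). The heads are then projected back to the query's last dimension unless the layer's output_shape argument is set. The NumPy sketch below is a simplified reimplementation of that idea, not the Keras source. It uses random weights, omits biases, dropout and masking, and the function name multi_head_attention is hypothetical. It shows why key_dim only affects internal shapes.

```python
import numpy as np

def multi_head_attention(query, value, num_heads, key_dim, value_dim=None, seed=0):
    """Hypothetical shape-level sketch of multi-head attention (key = value).

    query: (batch, Tq, d_query), value: (batch, Tv, d_value)
    """
    value_dim = key_dim if value_dim is None else value_dim
    rng = np.random.default_rng(seed)
    d_query, d_value = query.shape[-1], value.shape[-1]

    # Per-head projections: random here, learned in a real layer.
    w_q = rng.normal(size=(d_query, num_heads, key_dim))
    w_k = rng.normal(size=(d_value, num_heads, key_dim))
    w_v = rng.normal(size=(d_value, num_heads, value_dim))
    # Output projection maps all heads back to the query's feature size.
    w_o = rng.normal(size=(num_heads, value_dim, d_query))

    q = np.einsum("btd,dhk->bthk", query, w_q)  # (batch, Tq, heads, key_dim)
    k = np.einsum("bsd,dhk->bshk", value, w_k)  # (batch, Tv, heads, key_dim)
    v = np.einsum("bsd,dhv->bshv", value, w_v)  # (batch, Tv, heads, value_dim)

    # Scaled dot-product scores: (batch, heads, Tq, Tv)
    scores = np.einsum("bthk,bshk->bhts", q, k) / np.sqrt(key_dim)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over Tv

    context = np.einsum("bhts,bshv->bthv", weights, v)  # (batch, Tq, heads, value_dim)
    output = np.einsum("bthv,hvd->btd", context, w_o)   # (batch, Tq, d_query)
    return output, weights

rng = np.random.default_rng(1)
target = rng.normal(size=(1, 8, 16))
source = rng.normal(size=(1, 4, 16))
for kd in (2, 32):
    out, w = multi_head_attention(target, source, num_heads=2, key_dim=kd)
    print(kd, out.shape, w.shape)  # (1, 8, 16) (1, 2, 8, 4) for both values
```

key_dim and value_dim only set the width of the per-head projections (q, k, v). The final projection w_o always maps back to the query's feature size, so the output shape does not depend on them.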

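import tensorflow as tf
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32)
target = tf.keras.Input(shape=[8, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
print(output_tensor.shape)  # (None, 8, 16), same as with key_dim=2
print(weights.shape)        # (None, 2, 8, 4)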
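If the goal is a different output size, the documented output_shape constructor argument, not key_dim, appears to be what controls it. A minimal sketch, assuming a TensorFlow 2.x version where output_shape accepts a single integer:

```python
import tensorflow as tf

# key_dim only sets the per-head projection width; output_shape sets the
# size of the final projection (by default, the query's last dimension).
layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=32, output_shape=8)
target = tf.random.normal((1, 8, 16))
source = tf.random.normal((1, 4, 16))
output_tensor = layer(target, source)
print(output_tensor.shape)  # (1, 8, 8)
```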

Thanks in advance