Help needed with TimeDistributed MultiHeadAttention

I am trying to use MultiHeadAttention with a TimeDistributed layer, but I get an error:
TypeError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
call() missing 1 required positional argument: 'value'

Minimal code to reproduce:

import tensorflow as tf

query_input = tf.keras.Input(shape=(10, 144, 64), dtype='uint8')
value_input = tf.keras.Input(shape=(10, 144, 64), dtype='uint8')

attention_output = tf.keras.layers.TimeDistributed(
    tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64, dropout=0.1)
)((query_input, value_input))

Does somebody know how to fix this error? I understand that the problem arises because the MHA layer does not receive the second tensor as its value argument, but why? And how can I fix that?

Help would be greatly appreciated!


Welcome to the Tensorflow Forum!

As per documentation:

The TimeDistributed layer expects its input to be a tensor of shape (batch, time, …), or nested tensors each of which has shape (batch, time, …).

In your case, the tuple of two tensors is passed to the wrapped layer as a single argument, so MultiHeadAttention never receives its value argument, resulting in the error. You can pass the inputs directly to the MultiHeadAttention layer and apply the TimeDistributed layer afterwards.

Please find the working code below:

import tensorflow as tf

query_input = tf.keras.Input(shape=(10, 144, 64))
value_input = tf.keras.Input(shape=(10, 144, 64))

attention_output = tf.keras.layers.MultiHeadAttention(
    num_heads=4, key_dim=64, dropout=0.1)(query_input, value_input)

attention_time = tf.keras.layers.TimeDistributed(tf.keras.layers.Lambda(lambda x: x))(attention_output)

Note: This will not work with dtype='uint8', since there will be a mismatch between uint8 and the float32 type that MultiHeadAttention expects.
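A minimal sketch of the dtype fix, assuming the inputs can simply be cast: convert the uint8 tensors to float32 before they reach MultiHeadAttention.

```python
import tensorflow as tf

# MultiHeadAttention expects floating-point inputs, so cast the
# uint8 data to float32 first.
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)

query_u8 = tf.zeros((2, 10, 144, 64), dtype=tf.uint8)  # e.g. image-like data
value_u8 = tf.zeros((2, 10, 144, 64), dtype=tf.uint8)

query = tf.cast(query_u8, tf.float32)
value = tf.cast(value_u8, tf.float32)

out = mha(query, value)  # float32 output, same shape as the query
```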

Hey @chunduriv,

thanks for your fast response. The note about the dtype makes sense; that's an error in my example.

I am still confused about the TimeDistributed layer usage.
As far as I understand, your example only wraps the identity/Lambda layer, so TimeDistributed is not applied to the MHA layer. This means that [1] MHA is applied to q and k of shape 10x144x64, instead of [2] applying MHA 10 times with q and k of shape 144x64.

Did I get that wrong?
My goal is to achieve what I described in [2].
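A sketch of one possible way to get behaviour [2] without TimeDistributed, assuming the attention_axes argument of MultiHeadAttention: restricting attention to axis 2 makes the softmax run over the 144 positions independently for each of the 10 time steps (with shared weights), which should match the per-time-step application described above.

```python
import tensorflow as tf

# attention_axes=(2,) restricts attention to the positions axis,
# so each of the 10 time steps is attended independently rather
# than attending jointly over all 10*144 positions.
mha = tf.keras.layers.MultiHeadAttention(
    num_heads=4, key_dim=64, dropout=0.1, attention_axes=(2,))

query = tf.random.uniform((2, 10, 144, 64))  # (batch, time, positions, features)
value = tf.random.uniform((2, 10, 144, 64))

out = mha(query, value)  # same shape as the query: (2, 10, 144, 64)
```

This keeps a single set of attention weights shared across all time steps, which is also what TimeDistributed would give.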