Transformer from scratch, subclassing with Keras: "Gradients do not exist for variables"

Hello everyone,

I am having trouble debugging my code. I recently switched from PyTorch to TensorFlow and I am trying to reproduce the model from the paper "Attention Is All You Need" from scratch in tf.keras. I am not using the predefined tf.keras.layers.MultiHeadAttention; instead, I wrote my own version modelled after the diagram in the paper. The code is quite lengthy, so I will show it later; first I want to describe the problem. I am getting the following warning:

WARNING:tensorflow:Gradients do not exist for variables 
['transformer_1/encoder_1/positional_embedding_layer_3/embedding_3/embeddings:0',
'transformer_1/encoder_1/encoder_layer_4/sub_module_20/multi_headed_attention_12/attention_12/dense_65/kernel:0',
'transformer_1/encoder_1/encoder_layer_4/sub_module_20/multi_headed_attention_12/attention_12/dense_65/bias:0',
.......

Looking online, I found two possible reasons why this might happen:

  1. I am converting some tensors to NumPy arrays and thereby dropping their gradients.
  2. I am not properly connecting the graph.

I looked carefully and I am sure that I am not converting anything from the graph to NumPy, so the problem must be number 2 (or something else I don't know about yet).
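For reference, this is roughly how I reproduce the warning and check which variables end up without a gradient (just a sketch; model, loss_fn and the batches stand in for my actual training setup):
import tensorflow as tf

# Minimal sketch: find trainable variables that receive no gradient.
# `model`, `loss_fn`, `context_batch`, `x_batch` and `y_batch` are placeholders.
with tf.GradientTape() as tape:
    predictions = model((context_batch, x_batch))
    loss = loss_fn(y_batch, predictions)

grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    if grad is None:
        print("No gradient for:", var.name)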

I will show you the code and briefly explain some of my choices (it is quite long, but please bear with me).

  1. In the paper they say:
    We employ a residual connection [10] around each of the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. So I modelled that as a layer:
import numpy as np
import tensorflow as tf
from tensorflow import keras


class SubModule(tf.keras.layers.Layer):
    def __init__(self,Sublayer):
        """Wrap a sub-layer with a residual connection followed by layer normalization.

        Args:
            Sublayer (keras.Layer): A submodule in the encoder or decoder
        """
        super(SubModule, self).__init__()
        self.Sublayer = Sublayer
        self.LayerNorm = tf.keras.layers.LayerNormalization()
        self.Add = tf.keras.layers.Add()

    def call(self,x,k=None,v=None):
        if k is not None and v is not None:
            return self.LayerNorm(self.Add([x,self.Sublayer(x,k,v)]))
        else:
            return self.LayerNorm(self.Add([x,self.Sublayer(x)]))

The reason for the if statement is that, when I implement cross attention, I also need to pass in the encoder outputs.
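For example, for cross attention in the decoder I use it roughly like this (illustrative sketch, not my actual decoder code):
# Illustrative only: k and v come from the encoder output.
cross_attention = SubModule(MultiHeadedAttention(d_model=512, h=8))
x = tf.random.uniform((2, 10, 512))               # decoder stream
encoder_output = tf.random.uniform((2, 10, 512))  # encoder output
out = cross_attention(x, k=encoder_output, v=encoder_output)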

  2. My implementation of multi-head attention: it follows Figure 2 in the paper (a quick shape check comes right after the code).
class Attention(tf.keras.layers.Layer):
    def __init__(self,dims,use_mask=False):
        """_summary_

        Args:
            dims (int): The number of dimensions to attend to
            mask (tensor, optional): Mask parts of the input
        """
        super(Attention, self).__init__()
        self.d_k = tf.keras.layers.Dense(dims)
        self.d_q = tf.keras.layers.Dense(dims)
        self.d_v = tf.keras.layers.Dense(dims)

        self.dims = dims
        self.use_mask = use_mask
        self.last_attention_score = None

    def ScaledDotProductAttention(self,Q,V,K):
        score = tf.matmul(Q,K,transpose_b=True)/tf.math.sqrt(self.dims/1)
        if self.use_mask:
            mask = self.compute_causal_mask(score.shape[-1])
            score = score + mask
        score = keras.activations.softmax(score)
        return tf.matmul(score,V)
    
    def compute_causal_mask(self,score_dims):
        mask = np.triu(np.ones((score_dims,score_dims))*-np.inf,1)
        return tf.cast(mask,dtype=tf.float32)

    def call(self,q,k,v):
        Q,K,V = self.d_q(q),self.d_k(q),self.d_v(q)
        return self.ScaledDotProductAttention(Q,K,V)


class MultiHeadedAttention(tf.keras.layers.Layer):
    def __init__(self,d_model, h, use_mask=False):
        """_summary_

        Args:
            d_model (int): length of the embedding
            h (int): number of heads
            mask (tensor, optional): Mask parts of the input
        """
        super(MultiHeadedAttention,self).__init__()
        self.h = h
        self.heads = Attention(d_model/h, use_mask)
        self.reshape = tf.keras.layers.Reshape(target_shape=(-1,h,d_model//h))
        self.reverse = tf.keras.layers.Reshape(target_shape=(-1,d_model))
        self.WO = tf.keras.layers.Dense(d_model)

    def split(self,x):
        """_summary_

        Args:
            x (tensor): a tensor of shape (batch_size,seq_length,d_model)

        Returns:
            (tensor): a tensor of shape (batch_size,h,seq_length,d_model//h)
        """
        return tf.transpose(self.reshape(x),perm=[0,2,1,3])

    def concat(self,x):
        """_summary_

        Args:
            x (tensor): a tensor of shape (batch_size,h,seq_length,d_model//h)

        Returns:
            (tensor): a tensor of shape (batch_size,seq_length,d_model)
        """

        return self.reverse(tf.transpose(x,perm=[0,2,1,3]))

    def call(self,q,k,v):
        q = self.split(q)
        k = self.split(k)
        v = self.split(v)
        x = self.heads(q,k,v) # tensor of shape=(batch_size,h,seq_len,d_model//h)
        x = self.concat(x)
        return self.WO(x)
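As a sanity check, I run a quick shape test on random tensors (the hyperparameters here are just example values):
# Quick shape check with example values (d_model=512, h=8, batch=2, seq_len=10).
mha = MultiHeadedAttention(d_model=512, h=8)
dummy = tf.random.uniform((2, 10, 512))
out = mha(dummy, dummy, dummy)
print(out.shape)  # I expect (2, 10, 512)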
  3. The position-wise feed-forward network is simple and looks as follows:
class PositionWiseFeedForward(tf.keras.layers.Layer):
    def __init__(self,d_model,d_ff):
        super(PositionWiseFeedForward,self).__init__()
        self.l1 = tf.keras.layers.Dense(d_ff,activation="relu")
        self.l2 = tf.keras.layers.Dense(d_model)
    
    def call(self,x):
        return self.l2(self.l1(x))
  4. The encoder is a stack of N = 6 EncoderLayer blocks (a small smoke test follows the code):
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self,d_model,h,dff):
        super(EncoderLayer,self).__init__()
        self.fsm = SubModule(MultiHeadedAttention(d_model,h))
        self.ssm = SubModule(PositionWiseFeedForward(d_model,dff))
        
    def call(self,x):
        x = self.fsm(x,x,x)
        x = self.ssm(x)
        return x

class Encoder(tf.keras.layers.Layer):
    def __init__(self,N,d_model,max_seq_len,D,h,dff):
        super(Encoder,self).__init__()
        self.embedding = PositionalEmbeddingLayer(
            d_model = d_model,
            max_seq_len=max_seq_len,
            D=D)
        self.encoderStack = [
            EncoderLayer(
                d_model=d_model,
                h=h,
                dff=dff)
            for _ in range(N)]
                
    def call(self,x):
        posX = self.embedding(x)

        for i in range(len(self.encoderStack)):
            posX = self.encoderStack[i](posX)
        return posX
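And the smoke test I use for the encoder (again, example values; D is the input vocabulary size):
# Encoder smoke test with example values (vocab size 1000, only two layers).
enc = Encoder(N=2, d_model=512, max_seq_len=10, D=1000, h=8, dff=2048)
tokens = tf.random.uniform((2, 10), minval=0, maxval=1000, dtype=tf.int32)
print(enc(tokens).shape)  # I expect (2, 10, 512)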
  5. To save space, assume that the decoder is built from these blocks as well. The final transformer looks as follows:
class Transformer(keras.Model):
    def __init__(self,N, d_model, h, dff, max_seq_len, in_D, out_D):
        """_summary_

        Args:
            N (int): Number of layers in the encoder and decoder
            d_model (int): Dimensions of the embedding
            h (int): Number of heads
            dff (int): Dimensions of the positionwise feedforward network
            max_seq_len (int): Length of the sequence
            in_D (int): Size of the input vocabulary
            out_D (int): Size of the output vocabulary
        """
        super().__init__()
        self.encoder = Encoder(
            N=N,
            d_model=d_model,
            max_seq_len=max_seq_len,
            D=in_D,
            h=h,
            dff=dff)

        self.decoder = Decoder(
            N=N,
            d_model=d_model,
            max_seq_len=max_seq_len,
            D=out_D,
            h=h,
            dff=dff
        )

    def call(self, inputs):
        context,x = inputs
        z = self.encoder(context)
        y = self.decoder(x,z)

        # Something I saw on the internet to save memory
        try:
            del y._keras_mask
        except AttributeError:
            pass
    
        return y
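
For reference, this is roughly how I build and call the full model (example hyperparameters; the decoder is my own class built from the blocks above):
# Example values only; the output shape depends on the decoder's final projection.
transformer = Transformer(N=6, d_model=512, h=8, dff=2048, max_seq_len=10, in_D=1000, out_D=1000)
context = tf.random.uniform((2, 10), minval=0, maxval=1000, dtype=tf.int32)
x = tf.random.uniform((2, 10), minval=0, maxval=1000, dtype=tf.int32)
print(transformer((context, x)).shape)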

The positional embedding layer transforms an input of shape (batch, len) into one of shape (batch, len, d_model); it is also a Keras layer.
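Simplified, it does something like this (a sketch only, not my exact layer, but the input and output shapes are the same):
class PositionalEmbeddingLayer(tf.keras.layers.Layer):
    """Sketch: learned token embedding plus fixed sinusoidal positional encoding."""
    def __init__(self, d_model, max_seq_len, D):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(D, d_model)
        pos = np.arange(max_seq_len)[:, np.newaxis]                    # (max_seq_len, 1)
        i = np.arange(d_model)[np.newaxis, :]                          # (1, d_model)
        angle = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
        pe = np.zeros((max_seq_len, d_model))
        pe[:, 0::2] = np.sin(angle[:, 0::2])
        pe[:, 1::2] = np.cos(angle[:, 1::2])
        self.pos_encoding = tf.cast(pe[np.newaxis, ...], tf.float32)   # (1, max_seq_len, d_model)

    def call(self, x):
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)                                          # (batch, len, d_model)
        return x + self.pos_encoding[:, :seq_len, :]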

Can anyone help me pinpoint the mistake?
Thank you in advance.