tf.keras.Model and model class adaptation

Hello everyone. I'm new to the TensorFlow library and the Keras package, so I am having some trouble adapting a 3D U-Net, originally defined with tf.keras.Model, to a variant of another model defined through model class declarations (layer subclassing).

To start, I have a 3D dataset whose images have already been divided into equal patches of size 64x64x64. I then create a tf.data.Dataset object as the input for the model (following the original code); you can see the Dataset object and the shape of the input data in the attached image:

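Since attachments don't always render, here is a minimal sketch of how the Dataset is built (the arrays X and Y are hypothetical stand-ins for my real patches and masks, and the single channel is an assumption):

    import numpy as np
    import tensorflow as tf

    # Hypothetical stand-ins: N patches of 64x64x64 voxels with one channel,
    # plus matching binary masks.
    X = np.random.rand(40, 64, 64, 64, 1).astype("float32")
    Y = (np.random.rand(40, 64, 64, 64, 1) > 0.5).astype("float32")

    dataset = tf.data.Dataset.from_tensor_slices((X, Y)).batch(2)
    print(dataset.element_spec)
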
The dataset is then fed into a 3D U-Net model, originally defined as follows:

    # DOWNWARD PATH (encoder)
    conv1 = tf.keras.layers.Conv3D(conv_filters[0], 3, activation='relu', padding='same', data_format="channels_last")(inputs_)
    conv1 = tf.keras.layers.Conv3D(conv_filters[0], 3, activation='relu', padding='same')(conv1)
    bn1 = tf.keras.layers.BatchNormalization()(conv1)
    pool1 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(bn1)
    conv2 = tf.keras.layers.Conv3D(conv_filters[1], 3, activation='relu', padding='same')(pool1)
    conv2 = tf.keras.layers.Conv3D(conv_filters[1], 3, activation='relu', padding='same')(conv2)
    bn2 = tf.keras.layers.BatchNormalization()(conv2)
    pool2 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(bn2)
    conv3 = tf.keras.layers.Conv3D(conv_filters[2], 3, activation='relu', padding='same')(pool2)
    conv3 = tf.keras.layers.Conv3D(conv_filters[2], 3, activation='relu', padding='same')(conv3)
    bn3 = tf.keras.layers.BatchNormalization()(conv3)
    pool3 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(bn3)

    conv4 = tf.keras.layers.Conv3D(conv_filters[3], 3, activation='relu', padding='same')(pool3)
    conv4 = tf.keras.layers.Conv3D(conv_filters[3], 3, activation='relu', padding='same')(conv4)
    bn4 = tf.keras.layers.BatchNormalization()(conv4)

    # UPWARD PATH (decoder)
    up5 = tf.keras.layers.Conv3D(conv_filters[2], 2, activation='relu', padding='same')(tf.keras.layers.UpSampling3D(size=(2, 2, 2))(bn4))
    merge5 = tf.keras.layers.Concatenate(axis=-1)([bn3, up5])
    conv5 = tf.keras.layers.Conv3D(conv_filters[2], 3, activation='relu', padding='same')(merge5)
    conv5 = tf.keras.layers.Conv3D(conv_filters[2], 3, activation='relu', padding='same')(conv5)
    bn5 = tf.keras.layers.BatchNormalization()(conv5)
    up6 = tf.keras.layers.Conv3D(conv_filters[1], 2, activation='relu', padding='same')(tf.keras.layers.UpSampling3D(size=(2, 2, 2))(bn5))
    merge6 = tf.keras.layers.Concatenate(axis=-1)([bn2, up6])
    conv6 = tf.keras.layers.Conv3D(conv_filters[1], 3, activation='relu', padding='same')(merge6)
    conv6 = tf.keras.layers.Conv3D(conv_filters[1], 3, activation='relu', padding='same')(conv6)
    bn6 = tf.keras.layers.BatchNormalization()(conv6)
    up7 = tf.keras.layers.Conv3D(conv_filters[0], 2, activation='relu', padding='same')(tf.keras.layers.UpSampling3D(size=(2, 2, 2))(bn6))
    merge7 = tf.keras.layers.Concatenate(axis=-1)([bn1, up7])
    conv7 = tf.keras.layers.Conv3D(conv_filters[0], 3, activation='relu', padding='same')(merge7)
    conv7 = tf.keras.layers.Conv3D(conv_filters[0], 3, activation='relu', padding='same')(conv7)
    bn7 = tf.keras.layers.BatchNormalization()(conv7)
    outputs_ = tf.keras.layers.Conv3D(1, 1, activation='sigmoid')(bn7)

    model = tf.keras.Model(inputs=inputs_, outputs=outputs_)
    model.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=learning_rate), loss=bce_dice_loss(lambda_loss), metrics=[dice_coeff, "binary_crossentropy"])
    

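For context, inputs_, conv_filters, learning_rate and lambda_loss are defined earlier in my script (bce_dice_loss and dice_coeff are custom functions I omit here); roughly like this, where the single channel and the exact values are placeholders:

    # Placeholder definitions; the real script sets these earlier.
    inputs_ = tf.keras.Input(shape=(64, 64, 64, 1))  # 64^3 patches, 1 channel assumed
    conv_filters = [32, 64, 128, 256]                # example encoder widths
    learning_rate = 1e-4
    lambda_loss = 0.5
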
I would like to adapt a U-Net-with-transformer model, defined using class declarations, to the code above. This is what I have so far:

    import math  # needed for math.sqrt in SelfAttention below

    class SingleDeconv3DBlock(tf.keras.layers.Layer):

        def __init__(self,filters):
            super(SingleDeconv3DBlock, self).__init__()
            self.block = tf.keras.layers.Conv3DTranspose(filters=filters,
                                                         kernel_size=2, strides=2,
                                                         padding="valid",
                                                         output_padding=None)

        def call(self, inputs):        
            return self.block(inputs)
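        # Note: a kernel-2, stride-2 transposed convolution exactly doubles each
        # spatial dimension (the inverse of the 2x2x2 max pooling in the encoder).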

    class SingleConv3DBlock(tf.keras.layers.Layer):

        def __init__(self, filters,kernel_size):
            super(SingleConv3DBlock, self).__init__()
            self.kernel = kernel_size
            # Note: self.res is computed but never used, since padding='same' already handles it.
            self.res = tuple(map(lambda i: (i - 1) // 2, self.kernel))
            self.block = tf.keras.layers.Conv3D(filters= filters, 
                                                kernel_size=kernel_size, 
                                                strides=1, 
                                                padding='same')

        def call(self, inputs):
            return self.block(inputs)
        
    class Conv3DBlock(tf.keras.layers.Layer):

        def __init__(self, filters,kernel_size=(3,3,3)):
            super(Conv3DBlock, self).__init__()
            self.a= tf.keras.Sequential([
                                        SingleConv3DBlock(filters,kernel_size=kernel_size),
                                        tf.keras.layers.BatchNormalization(),
                                        tf.keras.layers.Activation('relu')
            ])
            

        def call(self, inputs):
            return self.a(inputs)
        
    class Deconv3DBlock(tf.keras.layers.Layer):

        def __init__(self, filters,kernel_size=(3,3,3)):
            super(Deconv3DBlock, self).__init__()
            self.a= tf.keras.Sequential([
                                        SingleDeconv3DBlock(filters=filters),
                                        SingleConv3DBlock(filters=filters,kernel_size=kernel_size),
                                        tf.keras.layers.BatchNormalization(),
                                        tf.keras.layers.Activation('relu')
            ])
    
        def call(self, inputs):
            return self.a(inputs)
        
    class SelfAttention(tf.keras.layers.Layer):

        def __init__(self, num_heads,embed_dim,dropout):
            super(SelfAttention, self).__init__()

            self.num_attention_heads = num_heads
            self.attention_head_size = int(embed_dim / num_heads)
            self.all_head_size = self.num_attention_heads * self.attention_head_size

            self.query=tf.keras.layers.Dense(self.all_head_size)
            self.key = tf.keras.layers.Dense(self.all_head_size)
            self.value = tf.keras.layers.Dense(self.all_head_size)                

            self.out=tf.keras.layers.Dense(embed_dim)
            self.attn_dropout=tf.keras.layers.Dropout(dropout)
            self.proj_dropout=tf.keras.layers.Dropout(dropout)

            self.softmax=tf.keras.layers.Softmax()

            self.vis=False

        def transpose_for_scores(self,x):
            new_x_shape=list(x.shape[:-1] + (self.num_attention_heads, self.attention_head_size))
            new_x_shape[0] = -1
            y = tf.reshape(x, new_x_shape)
            return tf.transpose(y,perm=[0,2,1,3])

        def call(self, hidden_states):
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

            query_layer = self.transpose_for_scores(mixed_query_layer)
            key_layer = self.transpose_for_scores(mixed_key_layer)
            value_layer = self.transpose_for_scores(mixed_value_layer)  
            attention_scores= query_layer @ tf.transpose(key_layer,perm=[0,1,3,2])
            attention_scores= attention_scores/math.sqrt(self.attention_head_size)
            attention_probs=self.softmax(attention_scores)
            weights = attention_probs if self.vis else None
            attention_probs = self.attn_dropout(attention_probs)

            context_layer= attention_probs @ value_layer
            context_layer=tf.transpose( context_layer, perm=[0,2,1,3])
            new_context_layer_shape = list(context_layer.shape[:-2] + (self.all_head_size,))
            new_context_layer_shape[0]= -1
            context_layer = tf.reshape(context_layer,new_context_layer_shape)
            attention_output = self.out(context_layer)
            attention_output = self.proj_dropout(attention_output)
            return attention_output, weights
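        # Sanity note: for a (batch, 64, 768) token sequence with 12 heads, the
        # attention scores are (batch, 12, 64, 64) and the returned output is
        # (batch, 64, 768), plus the raw weights when self.vis is True.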
        
    class Mlp(tf.keras.layers.Layer):

        def __init__(self, output_features, drop=0.):
            super(Mlp, self).__init__()
            self.a=tf.keras.layers.Dense(units=output_features,activation=tf.nn.gelu)
            self.b=tf.keras.layers.Dropout(drop)

        def call(self, inputs):
            x=self.a(inputs)
            return self.b(x)

    class PositionwiseFeedForward(tf.keras.layers.Layer):

        def __init__(self, d_model=768,d_ff=2048, dropout=0.1):
            super(PositionwiseFeedForward, self).__init__()
            self.a=tf.keras.layers.Dense(units=d_ff)
            self.b=tf.keras.layers.Dense(units=d_model)
            self.c=tf.keras.layers.Dropout(dropout)

        def call(self, inputs):
            return self.b(self.c(tf.nn.relu(self.a(inputs))))

    ##embeddings, projection_dim=embed_dim
    class PatchEmbedding(tf.keras.layers.Layer): 
        def __init__(self ,  cube_size, patch_size , embed_dim):
                super(PatchEmbedding, self).__init__()
                self.num_of_patches=int((cube_size[0]*cube_size[1]*cube_size[2])/(patch_size*patch_size*patch_size))
                self.patch_size=patch_size
                self.size = patch_size
                self.embed_dim = embed_dim

                self.projection = tf.keras.layers.Dense(embed_dim)

                self.clsToken = tf.Variable(tf.keras.initializers.GlorotNormal()(shape=(1 , 512 , embed_dim)) , trainable=True)

                self.positionalEmbedding = tf.keras.layers.Embedding(self.num_of_patches , embed_dim)
                self.patches=None
                self.lyer = tf.keras.layers.Conv3D(filters= self.embed_dim,kernel_size=self.patch_size, strides=self.patch_size,padding='valid')
                #embedding - basically is adding numerical embedding to the layer along with an extra dim  
            
        def call(self , inputs):
                patches =self.lyer(inputs)
                patches = tf.reshape(patches , (tf.shape(inputs)[0] , -1 , self.size * self.size * 3))
                patches = self.projection(patches)
                positions = tf.range(0 , self.num_of_patches , 1)[tf.newaxis , ...]
                positionalEmbedding = self.positionalEmbedding(positions)
                patches = patches + positionalEmbedding

                return patches, positionalEmbedding
        
    ##transformerblock
    class TransformerLayer(tf.keras.layers.Layer):
        def __init__(self ,  embed_dim, num_heads ,dropout, cube_size, patch_size):
            super(TransformerLayer,self).__init__()

            self.attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

            self.mlp_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

            # Note: mlp_dim (equal to the number of patches) is computed but not
            # actually used; the feed-forward width is fixed at 2048 below.
            self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
            
            self.mlp = PositionwiseFeedForward(embed_dim,2048)
            self.attn = SelfAttention(num_heads, embed_dim, dropout)
        
        def call(self, x, training=True):
            h=x
            x=self.attention_norm(x)
            x,weights= self.attn(x)
            x=x+h
            h=x

            x = self.mlp_norm(x)
            x = self.mlp(x)

            x = x + h

            return x, weights
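        # This is a standard pre-norm transformer block: LayerNorm -> attention ->
        # residual add, then LayerNorm -> feed-forward -> residual add.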

    class TransformerEncoder(tf.keras.layers.Layer):
        def __init__(self, embed_dim, num_heads, cube_size, patch_size, num_layers=12, dropout=0.1, extract_layers=[3, 6, 9, 12]):
            super(TransformerEncoder, self).__init__()
            # Keep the constructor arguments around so get_config() can serialize them.
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.cube_size = cube_size
            self.patch_size = patch_size
            self.num_layers = num_layers
            self.dropout = dropout
            self.extract_layers = extract_layers
            self.embeddings = PatchEmbedding(cube_size, patch_size, embed_dim)
            self.encoders = [TransformerLayer(embed_dim, num_heads, dropout, cube_size, patch_size) for _ in range(num_layers)]
        
        def call(self, inputs, training=True):
            extract_layers = []
            x = inputs
            x, _ = self.embeddings(x)

            for depth, layer in enumerate(self.encoders):
                x, _ = layer(x, training=training)
                if depth + 1 in self.extract_layers:
                    extract_layers.append(x)
            
            return extract_layers
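        # extract_layers=[3, 6, 9, 12] means the hidden states after encoder blocks
        # 3, 6, 9 and 12 are collected as the four skip connections used below.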
        
        def get_config(self):
            # The original version referenced attributes (e.g. self.img_shape) that
            # were never set, which raises AttributeError on serialization.
            config = super().get_config().copy()
            config.update({
                'embed_dim': self.embed_dim,
                'num_heads': self.num_heads,
                'cube_size': self.cube_size,
                'patch_size': self.patch_size,
                'num_layers': self.num_layers,
                'dropout': self.dropout,
                'extract_layers': self.extract_layers,
            })
            return config


    input_dim = 3
    output_dim = 3
    embed_dim = 768
    img_shape = (64,64,64)
    patch_size = 16
    num_heads = 12
    dropout = 0.1
    num_layers = 12
    ext_layers = [3, 6, 9, 12]
        
    patch_dim = [int(x / patch_size) for x in img_shape]
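    # 64 / 16 = 4, so patch_dim = [4, 4, 4] and the transformer produces
    # 4 * 4 * 4 = 64 tokens per input volume.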
    
    # Define Transformer layers
    transformer = TransformerEncoder(
        embed_dim,
        num_heads,
        img_shape,
        patch_size,
        num_layers,
        dropout,
        ext_layers
        )
    
    # U-Net Decoder
    decoder0 = tf.keras.Sequential([
        Conv3DBlock(32, (3,3,3)),
        Conv3DBlock(64, (3,3,3))]
        )

    decoder3 = tf.keras.Sequential([
        Deconv3DBlock(512),
        Deconv3DBlock(256),
        Deconv3DBlock(128)]
        )

    decoder6 = tf.keras.Sequential([
        Deconv3DBlock(512),
        Deconv3DBlock(256)]
        )

    decoder9 = Deconv3DBlock(512)

    decoder12_upsampler = SingleDeconv3DBlock(512)

    decoder9_upsampler = tf.keras.Sequential([
        Conv3DBlock(512),
        Conv3DBlock(512),
        Conv3DBlock(512),
        SingleDeconv3DBlock(256)]
        )

    decoder6_upsampler = tf.keras.Sequential([
            Conv3DBlock(256),
            Conv3DBlock(256),
            SingleDeconv3DBlock(128)]
        )

    decoder3_upsampler = tf.keras.Sequential(
            [Conv3DBlock(128),
            Conv3DBlock(128),
            SingleDeconv3DBlock(64)]
        )

    decoder0_header = tf.keras.Sequential(
            [Conv3DBlock(64),
            Conv3DBlock(64),
            SingleConv3DBlock(output_dim, (1,1,1))]
        ) 
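    # Note to self: unlike the original U-Net, which ends in a 1-channel sigmoid
    # Conv3D, decoder0_header ends with output_dim (= 3) linear channels, so the
    # labels and bce_dice_loss must be compatible with that shape.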

    
    z = transformer(inputs_)
    z0, z3, z6, z9, z12 = inputs_, z[0],z[1],z[2],z[3]
    # The transformer returns (batch, num_patches, embed_dim) token sequences; with
    # channels-last Conv3D decoders they can be reshaped straight back into a
    # (batch, 4, 4, 4, embed_dim) volume, so no transpose is needed here (the
    # transpose came from a channels-first original and would scramble the voxels).
    z3 = tf.reshape(z3, [-1, *patch_dim, embed_dim])
    z6 = tf.reshape(z6, [-1, *patch_dim, embed_dim])
    z9 = tf.reshape(z9, [-1, *patch_dim, embed_dim])
    z12 = tf.reshape(z12, [-1, *patch_dim, embed_dim])
    z12 = decoder12_upsampler(z12)
    z9 = decoder9(z9)
    z9 = decoder9_upsampler(tf.concat([z9, z12], 4))
    z6 = decoder6(z6)
    z6 = decoder6_upsampler(tf.concat([z6, z9], 4))
    z3 = decoder3(z3)
    z3 = decoder3_upsampler(tf.concat([z3, z6], 4))
    z0 = decoder0(z0)
    outputs_ = decoder0_header(tf.concat([z0, z3], 4))
    
    model = tf.keras.Model(inputs = inputs_, outputs = outputs_)    
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss=bce_dice_loss(lambda_loss), metrics=[dice_coeff, "binary_crossentropy"])

However, whenever I try to run the code, it always fails at the model.fit call with "InvalidArgumentError: Graph execution error:". As I understand it, this is caused by a shape mismatch between the dataset and the input the model expects, but I don't really understand where it goes wrong, as I am not familiar with the two ways of defining models in Keras.
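
To show what I have tried, this is how I compare the shapes the dataset yields against what the model expects (a minimal sketch using the dataset and model objects from above; pushing one batch through eagerly gives a much more readable traceback than the graph error):

    # Compare what the dataset yields with what the model expects.
    print("dataset:", dataset.element_spec)
    print("model input:", model.input_shape)
    print("model output:", model.output_shape)

    # Run a single batch through the model eagerly so any shape error surfaces
    # outside graph mode, with a readable stack trace.
    x_batch, y_batch = next(iter(dataset))
    pred = model(x_batch, training=False)
    print("pred:", pred.shape, "labels:", y_batch.shape)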

Please forgive me if my question is inexperienced or demanding. I am open to any suggestions.
Thank you.