How to implement a custom gradient function in a layer?

I want to implement a layer with custom functionality, meaning custom forward and backward computations. Implementing the forward pass in Keras is straightforward: simply define the computation inside the call method. The backward computation, however, does not seem to be as straightforward.

To make sure that I understand what I’m doing, I am trying to replicate the behaviour of the Dense layer by writing it from scratch.

Here is some helper code to test the layer.

import tensorflow as tf
from tensorflow import keras

def load_mnist_dataset():
    (X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()

    # scale images to the [0, 1] range
    X_train = X_train.astype("float32") / 255
    X_test = X_test.astype("float32") / 255

    # reshape dataset to (num_of_samples, height * width)
    X_train = X_train.reshape((X_train.shape[0], 28 * 28))
    X_test = X_test.reshape((X_test.shape[0], 28 * 28))

    # convert class vectors to one-hot encoded vectors (10 classes)
    Y_train = keras.utils.to_categorical(Y_train, 10)
    Y_test = keras.utils.to_categorical(Y_test, 10)
    
    return (X_train, Y_train), (X_test, Y_test)

def train_model(model, X, Y, batch_size=128, epochs=20):        
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    model.fit(X, Y, batch_size=batch_size, epochs=epochs)
    return model

def model(layer=keras.layers.Dense, activation="relu"):

    # input layer
    input_layer = keras.Input(shape=(28 * 28,))

    # hidden layer
    x = layer(512, activation=activation)(input_layer)
    
    # output layer
    output_layer = layer(10, activation=activation)(x)

    model = keras.Model(inputs=input_layer, outputs=output_layer, name="mnist_model")
    return model

Here is where I implement a custom dense layer. The __init__ and build methods are copied from here.

class CustomDense(keras.layers.Layer):
    
    def __init__(self, units=32, activation=None):
        super(CustomDense, self).__init__()
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="random_normal",
            trainable=True
        )

    def call(self, inputs):
        return custom_op(inputs, self.activation, self.w, self.b)

Following this page, I implemented the custom op that the call method uses.

@tf.custom_gradient
def custom_op(inputs, activation, weights, biases):
    
    # forward computation
    z = tf.matmul(inputs, weights) + biases
    if activation is not None:
        z = activation(z)
    
    # backward computation
    def grad(upstream):
        inputs_gradient = tf.matmul(upstream, tf.transpose(weights))
        weights_gradient = tf.matmul(tf.transpose(inputs), upstream)
        bias_gradient = upstream
        return inputs_gradient, weights_gradient, bias_gradient
    
    return z, grad

I am facing two issues:

  1. I am not able to pass the activation function to custom_op; the call fails.
  2. After removing the activation argument, the code still fails, and I don’t know why. Maybe the gradients are not being used.
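
For reference, here is a minimal sketch of how both issues might be avoided. It rests on a few assumptions. tf.custom_gradient documents its positional arguments as tensor inputs, so the activation (a string or a callable) is kept out of the op and applied afterwards in call, where automatic differentiation handles it. The bias gradient is summed over the batch axis so that its shape matches the bias, which is a likely culprit for the second failure but not a confirmed one. The name dense_affine is hypothetical, and the inputs are assumed to be 2-D (batch, features).

```python
import tensorflow as tf
from tensorflow import keras


@tf.custom_gradient
def dense_affine(inputs, weights, biases):
    """Affine part of a dense layer (hypothetical helper): inputs @ weights + biases."""
    # forward computation
    z = tf.matmul(inputs, weights) + biases

    # backward computation; upstream has the shape of z: (batch, units)
    def grad(upstream):
        inputs_gradient = tf.matmul(upstream, weights, transpose_b=True)
        weights_gradient = tf.matmul(inputs, upstream, transpose_a=True)
        # biases are broadcast over the batch in the forward pass, so their
        # gradient is summed over the batch axis to get shape (units,)
        biases_gradient = tf.reduce_sum(upstream, axis=0)
        return inputs_gradient, weights_gradient, biases_gradient

    return z, grad


class CustomDense(keras.layers.Layer):

    def __init__(self, units=32, activation=None):
        super().__init__()
        self.units = units
        # accepts None, a string such as "relu", or a callable
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # read the variables into tensors so the op only receives tensors
        w = tf.convert_to_tensor(self.w)
        b = tf.convert_to_tensor(self.b)
        z = dense_affine(inputs, w, b)
        # the activation stays outside the custom op and is differentiated
        # automatically
        return self.activation(z)


# sanity check: compare the hand-written gradients with automatic differentiation
x = tf.random.normal((4, 3))
w = tf.random.normal((3, 2))
b = tf.random.normal((2,))

with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, w, b])
    custom_loss = tf.reduce_sum(tf.square(dense_affine(x, w, b)))
    reference_loss = tf.reduce_sum(tf.square(tf.matmul(x, w) + b))

custom_grads = tape.gradient(custom_loss, [x, w, b])
reference_grads = tape.gradient(reference_loss, [x, w, b])
```

If the hand-written gradients are right, custom_grads and reference_grads should agree up to floating-point error, and CustomDense could then be passed to the model function above in place of keras.layers.Dense.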

Any assistance would be appreciated.

I wonder if anyone has even seen this post.