How to create a Keras layer with a custom gradient *and learnable parameters* in TF2.0?

Hi, this is similar to the Stack Overflow question "How to create a keras layer with a custom gradient in TF2.0?"

The difference is that I would like to introduce a learnable parameter into the custom layer I am training.

Here’s a toy example of my current approach:

import tensorflow as tf

# Function with a custom gradient (y = x * s)
@tf.custom_gradient
def scaler(x, s):
  def grad(upstream):
    dy_dx = s
    dy_ds = x
    return dy_dx, dy_ds
  return x * s, grad

# Keras Layer with trainable parameter
class TestLayer(tf.keras.layers.Layer):
  def build(self, input_shape):
    self.scale = self.add_weight("scale", 
                                 shape=[1,], 
                                 initializer=tf.keras.initializers.Constant(value=2.0), 
                                 trainable=True)
  def call(self, inputs):
    return scaler(inputs, self.scale)

# Creates Keras Model that uses the layer
def Model():
  x_in = tf.keras.layers.Input(shape=(1,))
  x_out = TestLayer()(x_in)
  return tf.keras.Model(inputs=x_in, outputs=x_out, name="fp8_test")

# Create a toy dataset: we want to learn `scale` such that 5 = 2 * scale (i.e., `scale` should converge to ~2.5)
def Dataset():
  inps = tf.ones(shape=(10**5,)) * 2 # inputs 
  expected = tf.ones(shape=(10**5,)) * 5 # targets 
  data_in = tf.data.Dataset.from_tensors(inps)
  data_exp = tf.data.Dataset.from_tensors(expected)
  dataset = tf.data.Dataset.zip((data_in, data_exp))
  return dataset

model = Model()
model.summary()

dataset = Dataset()

# Use `MSE` loss and `SGD` optimizer
model.compile(
    loss=tf.keras.losses.MSE,
    optimizer=tf.keras.optimizers.SGD(),
)

model.fit(dataset, epochs=100)

This fails with the following shape-related error in the optimizer:

ValueError: Shapes must be equal rank, but are 1 and 2 for '{{node SGD/SGD/update/ResourceApplyGradientDescent}} = ResourceApplyGradientDescent[T=DT_FLOAT, use_locking=true](fp8_test/test_layer_1/ReadVariableOp/resource, SGD/Identity, SGD/IdentityN)' with input shapes: [], [], [100000,1].
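The shapes in the error line up with `grad`: it returns `x` as the gradient for `s`, and `x` has the batch shape (`[100000, 1]` here), while the `scale` variable has shape `[1]`. `tf.custom_gradient` expects the gradient function to return one gradient per input, each with that input's shape and already multiplied by `upstream` (the chain rule); the original `grad` ignores `upstream` entirely. Below is a minimal sketch of a corrected version, not a definitive fix. It assumes `scale` keeps its single-element shape, so the per-example contributions are summed and reshaped to `[1]`. It also reads the variable into a tensor with `tf.convert_to_tensor` before calling `scaler`, which is an added precaution rather than something the question or the docs require. The helper `mse_loss`, the batch of 8 examples, and the 0.01 step size are arbitrary choices for illustration.

```python
import tensorflow as tf


@tf.custom_gradient
def scaler(x, s):
    """Returns y = x * s together with a hand-written gradient function."""
    def grad(upstream):
        # Chain rule: scale each local derivative by the incoming gradient
        # `upstream`, which has the same shape as y (and therefore x).
        dy_dx = upstream * s
        # s is broadcast over every element of x, so its gradient is the sum
        # of all per-element contributions, reshaped to s's own shape ([1]).
        dy_ds = tf.reshape(tf.reduce_sum(upstream * x), tf.shape(s))
        return dy_dx, dy_ds
    return x * s, grad


class TestLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale",
            shape=(1,),
            initializer=tf.keras.initializers.Constant(2.0),
            trainable=True,
        )

    def call(self, inputs):
        # Precaution (an assumption): pass the variable's current value as a
        # plain tensor. The read happens under the tape, so gradients still
        # flow back to `self.scale`.
        return scaler(inputs, tf.convert_to_tensor(self.scale))


layer = TestLayer()
x = tf.fill((8, 1), 2.0)  # 8 example inputs, shape [batch, 1]
y = tf.fill((8, 1), 5.0)  # targets


def mse_loss():
    # Hypothetical helper: mean squared error of the layer on the toy batch.
    return tf.reduce_mean(tf.square(layer(x) - y))


# The gradient for `scale` now has the variable's shape, not the batch shape.
with tf.GradientTape() as tape:
    loss = mse_loss()
(grad_scale,) = tape.gradient(loss, layer.trainable_variables)
print(tuple(grad_scale.shape), float(grad_scale[0]))  # (1,) -4.0

# Plain gradient descent moves `scale` toward 5 / 2 = 2.5.
for _ in range(200):
    with tf.GradientTape() as tape:
        loss = mse_loss()
    (g,) = tape.gradient(loss, layer.trainable_variables)
    layer.scale.assign_sub(0.01 * g)
print(round(float(layer.get_weights()[0][0]), 4))  # 2.5
```

With the gradient shapes matching the variable, the same layer should also be usable under `model.compile`/`model.fit`, though this sketch only checks a hand-written training loop.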

I’ve been staring at the docs for a while and I’m stumped as to why this isn’t working. I would really appreciate any input on how to fix this toy example.

Thanks in advance.

I attempted the above toy example.
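Separately from the gradient, the 100000 in `[100000,1]` is the batch size, and it comes from how the toy dataset is built: `tf.data.Dataset.from_tensors` turns the whole tensor into a single dataset element, so each epoch is one step over all the values. `from_tensor_slices` instead yields one element per example, which `.batch()` can then group into mini-batches. A small sketch comparing the two, using 10 examples and a batch size of 4 (both arbitrary):

```python
import tensorflow as tf

inps = tf.ones(shape=(10,)) * 2      # 10 examples instead of 10**5
expected = tf.ones(shape=(10,)) * 5

# from_tensors: the whole tensor is ONE dataset element.
whole = tf.data.Dataset.zip((tf.data.Dataset.from_tensors(inps),
                             tf.data.Dataset.from_tensors(expected)))
print(int(whole.cardinality()), tuple(whole.element_spec[0].shape))  # 1 (10,)

# from_tensor_slices: one element per example (reshaped to [N, 1] here);
# batch() then groups consecutive examples.
sliced = tf.data.Dataset.from_tensor_slices(
    (inps[:, None], expected[:, None])).batch(4)
print([tuple(b.shape) for b, _ in sliced])  # [(4, 1), (4, 1), (2, 1)]
```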