Creating a Custom RNN Cell in Keras

I’m just getting started with Keras, and as a first exercise I’d like to build a class inheriting from keras.layers.Layer that implements a custom, simple, albeit somewhat unconventional RNN. I have clear requirements for how this RNN should work:

  1. Two inputs: input0 and input1, at each time step.
  2. One state: state0, at each time step.
  3. Two weights: c0 and c1.
  4. The cell is a purely linear function as follows: output0 = input1 + self.c0 * (state0 - input0), states0_new = input1 + self.c1 * (state0 - input0).
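To pin down what the cell in item 4 computes over a whole sequence, here is a plain-Python sketch for a single sequence, independent of Keras. The function name `run_cell` is hypothetical, and the initial state of 0.0 is an assumption (it matches Keras's default of zero initial states when none are passed to the RNN layer).

```python
def run_cell(sequence, c0, c1, state0=0.0):
    """Apply the linear cell from item 4 to one sequence of (input0, input1) pairs.

    Hypothetical helper: returns the list of output0 values and the final state.
    """
    outputs = []
    for input0, input1 in sequence:
        diff = state0 - input0              # shared term (state0 - input0)
        outputs.append(input1 + c0 * diff)  # output0 for this time step
        state0 = input1 + c1 * diff         # states0_new becomes the next step's state0
    return outputs, state0


outputs, final_state = run_cell([(1.0, 2.0), (3.0, 4.0)], c0=0.5, c1=0.25)
print(outputs, final_state)  # [1.5, 3.375] 3.6875
```

Such a reference loop is handy for checking a Keras implementation: with the weights fixed to known values, the layer's outputs should match it step for step.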

Note that there is no activation function. I’d like to build this RNN model in Keras and then train it on time-series data consisting of (input0, input1) pairs at each time step, predicting only input0 at the next time step from the data up to the current time.
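The "predict input0 at the next time step" setup amounts to shifting the series by one step: the inputs are the pairs up to time T-2, and the targets are input0 from time 1 onward. A minimal sketch in plain Python, where `make_training_pair` is a hypothetical name:

```python
def make_training_pair(series):
    """Split one series of (input0, input1) pairs into model inputs and targets.

    Hypothetical helper:
      inputs[t]  = (input0[t], input1[t])  for t = 0 .. T-2
      targets[t] = input0[t + 1]           (the value to predict one step ahead)
    """
    if len(series) < 2:
        raise ValueError("need at least two time steps")
    inputs = [tuple(pair) for pair in series[:-1]]  # drop the last step
    targets = [pair[0] for pair in series[1:]]      # input0, shifted by one step
    return inputs, targets


inputs, targets = make_training_pair([(1.0, 10.0), (2.0, 20.0), (3.0, 30.0)])
print(inputs)   # [(1.0, 10.0), (2.0, 20.0)]
print(targets)  # [2.0, 3.0]
```

In Keras terms this corresponds to inputs of shape (batch_size, timesteps, 2) and targets of shape (batch_size, timesteps, 1). Getting a prediction at every step requires `RNN(cell, return_sequences=True)`, since by default the RNN layer returns only the output of the last time step.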

Here is my attempted solution, which currently raises an error:

```python
from keras.layers import Layer, Activation, Input, RNN
from keras import backend as K

class MyCell(Layer):
    def __init__(self, **kwargs):
        super(MyCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.c0 = self.add_weight(shape=(1, 1),
                                  initializer='uniform',
                                  name='c0') #Just a single scalar weight
        self.c1 = self.add_weight(shape=(1, 1),
                                  initializer='uniform',
                                  name='c1') #Just a single scalar weight
        super(MyCell, self).build(input_shape)

    def call(self, inputs, states):
        input0, input1 = inputs[0], inputs[1]
        state0 = states[0]

        output0 = input1 + self.c0 * (state0 - input0) #To be evaluated at each time point
        states0_new = input1 + self.c1 * (state0 - input0) #To be evaluated at each time point

        return [output0], [states0_new]

# Let's use this cell in a RNN layer:
cell = MyCell()
inputs = Input((None, 2)) #Don't really understand the deal with None here.
rnn_layer = RNN(cell)
outputs = rnn_layer(inputs)
```

The error concerns the state and reads: “Shapes must be equal rank, but are 1 and 2”.

I can’t figure out what’s wrong. I suspect it has something to do with how Keras handles batching, since it adds the batch as an extra dimension of every tensor, but I’m honestly not sure, and I’m new to Keras. I’d appreciate any help figuring out what the issue is. Thanks.
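The batching suspicion can be made concrete. With `Input((None, 2))`, the batch dimension is left out and `None` is the number of time steps, so the layer accepts sequences of any length. The RNN layer then loops over time and hands the cell one step for the whole batch: `inputs` has shape (batch_size, 2) and the single state has shape (batch_size, 1). Indexing the first axis, as in `inputs[0]`, therefore selects the first *sample* (a rank-1 tensor of shape (2,)) rather than the first *feature*, which is a plausible source of the rank-1 vs rank-2 mismatch. The NumPy sketch below mimics that loop with feature-axis slicing; `cell_step` and `run_rnn` are hypothetical names, and the zero initial state is an assumption.

```python
import numpy as np


def cell_step(inputs, states, c0, c1):
    """One time step for the whole batch, mirroring what MyCell.call receives.

    inputs: array of shape (batch_size, 2), one time step for every sample.
    states: list holding one array of shape (batch_size, 1).
    """
    # Slice along the feature axis and keep rank 2, so both are (batch_size, 1).
    input0 = inputs[:, 0:1]
    input1 = inputs[:, 1:2]
    state0 = states[0]
    diff = state0 - input0
    output0 = input1 + c0 * diff
    state0_new = input1 + c1 * diff
    return output0, [state0_new]


def run_rnn(x, c0, c1):
    """Loop over the time axis of x, shape (batch_size, timesteps, 2).

    timesteps may be any length, which is what the None in Input((None, 2)) allows.
    """
    states = [np.zeros((x.shape[0], 1))]  # assumed zero initial state
    outputs = []
    for t in range(x.shape[1]):
        out, states = cell_step(x[:, t, :], states, c0, c1)
        outputs.append(out)
    # Shapes: (batch_size, timesteps, 1) and (batch_size, 1)
    return np.stack(outputs, axis=1), states[0]


x = np.array([[[1.0, 2.0], [3.0, 4.0]],
              [[0.0, 1.0], [2.0, 2.0]]])  # batch_size=2, timesteps=2
step = x[:, 0, :]           # what the cell sees at t=0: shape (2, 2)
print(step[0].shape)        # (2,)   first sample, not first feature
print(step[:, 0:1].shape)   # (2, 1) first feature for every sample
outs, final_state = run_rnn(x, c0=0.5, c1=0.25)
print(outs[:, :, 0])        # [[1.5, 3.375], [1.0, 1.5]]
```

On the Keras side, the same slicing (`inputs[:, 0:1]`, `inputs[:, 1:2]`) would go inside `MyCell.call`. The `RNN` layer also expects the cell to declare the size of its state, for example `self.state_size = 1` in `__init__`, which the class above does not do.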