Changing tf.Variable shape inside a model's call method

Hi, I am still working with a special recurrent model, which (as far as I know; if you have another idea, let me know) requires me to store the previous outputs of a batch in a buffer. In general, the last batch of my dataset will be smaller than the nominal batch size. Because of this, unless I take special precautions (e.g. dropping the last batch), I get a runtime error. I am looking for a way to adapt the buffer size to the current batch size.

My code is as follows:

class FRAE(tf.keras.Model):
    """Encoder/decoder model that feeds its previous outputs back in via a buffer.

    For every time step, the step's features are concatenated with a rolling
    buffer of earlier decoded outputs and encoded (l1 -> l2 -> ls). The code is
    concatenated with the buffer again and decoded (l3 -> l4 -> l5). The decoded
    step is then pushed onto the front of the buffer.

    The buffer is a non-trainable ``tf.Variable`` whose leading (batch)
    dimension is declared unknown. This lets ``call`` resize it to the batch
    size of the incoming data, such as a smaller final batch, instead of
    failing on a shape mismatch.

    Args:
        latent_dim: Width of the code layer ``ls``.
        shape: Feature shape of one step. ``shape[0]`` is the number of buffer
            columns per stored step, and ``shape[-1]`` is the width of the
            decoded output. NOTE(review): the buffer shift keeps its width
            constant only if these are equal, e.g. ``shape == (n_features,)``;
            confirm.
        ht: Number of past steps held in the buffer.
        n1, n2: Encoder hidden-layer widths.
        n3, n4: Decoder hidden-layer widths.
        batch_size: Initial number of buffer rows.
        bypass: When true, ``call`` returns its input unchanged.
        trainable: Assigned to the Keras ``trainable`` attribute.
    """

    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, batch_size=1, bypass=True, trainable=True, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        self.batch_size = batch_size

        self.SetupBuffer(batch_size, shape[0], ht)
        self.bypass = bypass
        self.trainable = trainable

        # Encoder: (step features, buffer) -> latent code.
        self.l1 = tf.keras.layers.Dense(n1, activation='tanh', input_shape=shape)
        self.l2 = tf.keras.layers.Dense(n2, activation='tanh')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')

        # Decoder: (latent code, buffer) -> step with shape[-1] features.
        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')

    def SetupBuffer(self, batch_size, input_dim, ht):
        """Create the zero-filled feedback buffer of shape (batch_size, input_dim * ht).

        The variable is declared with an unknown leading dimension, so later
        ``assign`` calls may change its number of rows.
        """
        width = input_dim * ht
        self.buffer = tf.Variable(
            initial_value=tf.zeros(shape=(batch_size, width), dtype=tf.float32),
            trainable=False,
            shape=tf.TensorShape([None, width]),
        )

    def call(self, x):
        """Run the feedback encode/decode loop over axis 1 of ``x``.

        Args:
            x: Input sliced as ``x[:, i, :]`` per step, presumably
                (batch, time, features).

        Returns:
            The decoded steps as (batch, time, shape[-1]), or ``x`` itself
            when ``bypass`` is set.
        """
        if self.bypass:
            return x

        # Fit the buffer to this batch. Rows that still have a sample are kept,
        # and rows for additional samples are zero-filled. This covers a
        # smaller final batch and the full batch that follows it. With an
        # unchanged batch size the values are left as they were.
        batch = tf.shape(x)[0]
        kept = self.buffer[:batch]
        missing = batch - tf.shape(kept)[0]
        fill = tf.zeros([missing, self.buffer.shape[-1]], dtype=self.buffer.dtype)
        self.buffer.assign(tf.concat([kept, fill], axis=0))

        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[1])
        for i in tf.range(tf.shape(x)[1]):
            xslice = x[:, i, :]
            encoded = self.ls(self.l2(self.l1(tf.concat((xslice, self.buffer), axis=1))))
            y = self.l5(self.l4(self.l3(tf.concat([encoded, self.buffer], axis=1))))
            decoded = decoded.write(i, y)
            # Roll the history: newest output in front, last shape[0] columns dropped.
            self.buffer.assign(tf.concat([y, self.buffer[:, :-self.shape[0]]], axis=1))

        # TensorArray.stack() is step-major; swap the first two axes back.
        return tf.transpose(decoded.stack(), [1, 0, 2])
    

I would like to do something like

def call(self,x):
    # Wished-for (non-working) approach from the question: rebuild the buffer
    # for the current batch size at the start of every call.
    # NOTE(review): SetupBuffer stores the new variable on self.buffer itself and
    # has no return statement, so this assignment would then overwrite
    # self.buffer with None.
    # NOTE(review): Keras usually traces call() into a tf.function, where creating
    # a new tf.Variable after the first trace is rejected; presumably this is
    # the failure described in the question.
    self.buffer = self.SetupBuffer(tf.shape(x)[0], self.shape[0], self.ht)

However, this does not run, because self.buffer cannot be set like this within the call method. None of the approaches I tried succeeded. Is there some way to dynamically adjust the shape of self.buffer so that its dimension 0 matches that of x?

My only solution so far is to use a callback during training that calls SetupBuffer right before each batch.

Any ideas are welcome.

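The root cause of the issue is that a tf.Variable's shape is fixed when it is created (by default, to the shape of its initial value). TensorFlow's graph execution relies on these shapes staying consistent so that it can optimize the graph. A variable created this way is therefore unsuitable for direct resizing within the model's call method. Creating a new variable there instead, which is what calling SetupBuffer does, is also rejected once Keras has traced call into a tf.function. (One exception: a variable created with a partially unknown shape, e.g. shape=tf.TensorShape([None, input_dim * ht]), can later be assigned values with a different leading dimension.)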

Thank you. So it is not possible to adjust the buffer size dynamically? Do you perhaps see some other way to store the output of the model so that I can feed it back to the input? Maybe if I do things differently I can solve the problem.

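Tim_Wolfe kindly helped me change my model so that the tf.Variable is no longer necessary. Note that the state is now created as zeros at the start of every call: it automatically matches each batch size, but it is no longer carried over from one batch to the next. Apparently, this works: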

class StatefulBufferLayer(tf.keras.layers.Layer):
    """One recurrent step that feeds earlier outputs back in through ``state``.

    ``call`` works in three stages:
    1. Concatenate the step inputs with ``state`` and encode them
       (l1 -> l2 -> ls).
    2. Concatenate the code with ``state`` again and decode it
       (l3 -> l4 -> l5) to ``output_dim`` features.
    3. Push that output onto the front of the state and drop the last
       ``output_dim`` columns, i.e. the oldest stored output.

    Args:
        output_dim: Units of the final Dense layer (features per output step).
        n1, n2: Encoder hidden widths (default 10).
        latent_dim: Width of the code layer ``ls`` (default 5).
        n3, n4: Decoder hidden widths (default 10).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, output_dim, *, n1=10, n2=10, latent_dim=5, n3=10, n4=10, **kwargs):
        super(StatefulBufferLayer, self).__init__(**kwargs)
        self.l1 = tf.keras.layers.Dense(n1, activation='tanh')
        self.l2 = tf.keras.layers.Dense(n2, activation='tanh')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')
        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(output_dim, activation='linear')

    def call(self, inputs, state):
        """Advance the recurrence by one step.

        Args:
            inputs: Features of the current step, presumably (batch, features).
            state: Previously produced outputs, newest first, as (batch, width).

        Returns:
            ``(new_output, new_state)``. ``new_output`` has ``output_dim``
            features. ``new_state`` is ``new_output`` followed by ``state``
            without its last ``output_dim`` columns.
        """
        encoded = self.ls(self.l2(self.l1(tf.concat([inputs, state], axis=-1))))
        new_output = self.l5(self.l4(self.l3(tf.concat([encoded, state], axis=-1))))

        # Drop the OLDEST output (the trailing columns), not the newest one.
        # Slicing `state[:, units:]` would discard the most recent output and
        # leave every slot after the first at zero. A static slice bound also
        # keeps the state's last dimension known at trace time.
        units = self.l5.units
        new_state = tf.concat([new_output, state[:, :-units]], axis=-1)
        return new_output, new_state


class FRAE(tf.keras.Model):
    """Feedback model that applies a StatefulBufferLayer to every time step.

    The feedback state is created inside ``call`` from the current batch, so
    any batch size works, including a smaller final batch. The state starts
    from zeros on every call and is not carried over between batches.

    Args:
        shape: Passed to StatefulBufferLayer as ``output_dim``, i.e. the unit
            count of its final Dense layer, so it must be an integer.
            NOTE(review): the state holds exactly ``ht`` past outputs only if
            this equals the input feature count; confirm.
        ht: Multiplier for the state width (input features * ht).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, shape, ht, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.shape = shape
        self.ht = ht
        self.stateful_buffer_layer = StatefulBufferLayer(shape)

    def call(self, x):
        """Run the step layer along axis 1 of ``x`` and stack the outputs.

        Args:
            x: Rank-3 input, presumably (batch, time, features). Axis 1 must
                be statically known because ``tf.unstack`` needs its length.

        Returns:
            The per-step outputs stacked along axis 1.
        """
        # Use the static feature count when available. This keeps the state's
        # width known while the Dense layers are built under tf.function. Fall
        # back to the dynamic value only when it is not statically known.
        features = x.shape[2]
        if features is None:
            features = tf.shape(x)[2]
        # Only the batch dimension is taken from the runtime shape.
        state = tf.zeros(shape=(tf.shape(x)[0], features * self.ht))

        outputs = []
        for step in tf.unstack(x, axis=1):
            new_output, state = self.stateful_buffer_layer(step, state)
            outputs.append(new_output)
        return tf.stack(outputs, axis=1)

Maybe this helps someone else some day.