Getting `ValueError: as_list() is not defined on an unknown TensorShape.` when trying to tokenize as part of the model

I am trying to do tokenization as part of my model, as it will reduce my CPU usage, and RAM, on the other hand, it will utilize my GPU more. But I am facing an issue saying ValueError: as_list() is not defined on an unknown TensorShape.

I have created a Layer called TokenizationLayer which takes care of the tokenization, and defines as:

class TokenizationLayer(Layer):
    def __init__(self, max_length, **kwargs):
        super(TokenizationLayer, self).__init__(**kwargs)
        self.max_length = max_length
        self.tokenizer = Tokenizer()

    def build(self, input_shape):
        super(TokenizationLayer, self).build(input_shape)

    def tokenize_sequences(self, x):
        # Tokenization function
        return self.tokenizer.texts_to_sequences([x.numpy()])[0]

    def call(self, inputs):
        # Use tf.py_function to apply tokenization element-wise
        sequences = tf.map_fn(lambda x: tf.py_function(self.tokenize_sequences, [x], tf.int32), inputs, dtype=tf.int32)
        # Masking step
        mask = tf.math.logical_not(tf.math.equal(sequences, 0))
        return tf.where(mask, sequences, -1)  # Using -1 as a mask value

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.max_length)  # Use self.max_length instead of trying to access shape

But it keeps giving me an error saying as_list() is not defined on an unknown TensorShape.

Here is the complete code, if you need it:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Embedding, LSTM, Dense, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

class TokenizationLayer(Layer):
    def __init__(self, max_length, **kwargs):
        super(TokenizationLayer, self).__init__(**kwargs)
        self.max_length = max_length
        self.tokenizer = Tokenizer()

    def build(self, input_shape):
        super(TokenizationLayer, self).build(input_shape)

    def tokenize_sequences(self, x):
        # Tokenization function
        return self.tokenizer.texts_to_sequences([x.numpy()])[0]

    def call(self, inputs):
        # Use tf.py_function to apply tokenization element-wise
        sequences = tf.map_fn(lambda x: tf.py_function(self.tokenize_sequences, [x], tf.int32), inputs, dtype=tf.int32)
        # Masking step
        mask = tf.math.logical_not(tf.math.equal(sequences, 0))
        return tf.where(mask, sequences, -1)  # Using -1 as a mask value

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.max_length)  # Use self.max_length instead of trying to access shape

# Build the model with the custom tokenization layer
def build_model(vocab_size, max_length):
    input1 = Input(shape=(1,), dtype=tf.string)
    input2 = Input(shape=(1,), dtype=tf.string)

    # Tokenization layer
    tokenization_layer = TokenizationLayer(max_length)
    embedded_seq1 = tokenization_layer(input1)
    embedded_seq2 = tokenization_layer(input2)

    # Embedding layer for encoding strings
    embedding_layer = Embedding(input_dim=vocab_size, output_dim=128, input_length=max_length)

    # Encode first string
    lstm_out1 = LSTM(64)(embedding_layer(embedded_seq1))

    # Encode second string
    lstm_out2 = LSTM(64)(embedding_layer(embedded_seq2))

    # Concatenate outputs
    concatenated = Concatenate()([lstm_out1, lstm_out2])

    # Dense layer for final output
    output = Dense(1, activation='relu')(concatenated)

    # Build model
    model = Model(inputs=[input1, input2], outputs=output)
    return model

string1 = "hello world"
string2 = "foo bar baz"

max_length = max(len(string1.split()), len(string2.split()))

model = build_model(vocab_size=1000, max_length=max_length)
model.summary()

labels = tf.random.normal((1, 5))
model.compile(optimizer='adam', loss='mse')
model.fit([tf.constant([string1]), tf.constant([string2])], labels, epochs=10, batch_size=1, validation_split=0.2)