Custom Keras model saves fine but cannot be loaded again (shape mismatch on load)

I have a custom model based on the keras.Model class, which I am trying to save for later use. While saving the model works without a problem, reloading the same model throws the following error:

ValueError: Cannot assign value to variable ' dense_4/kernel:0': Shape mismatch.The variable shape (24, 512), and the assigned value shape (784, 256) are incompatible.

I am using the .keras file format and the corresponding model.save and keras.saving.load_model methods.
I also tried pickle and dill; both failed to save the model.
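
The pickle/dill attempts looked roughly like this (a sketch, not the exact code I ran; dill was used the same way):

import pickle

# Attempt to serialize the trained HighNet instance directly; this is roughly
# what I tried, and it also failed for this subclassed model.
with open("high_model.pkl", "wb") as f:
    pickle.dump(model, f)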

Everything above the Encoder class is the math for the loss calculation; it can (probably) be ignored, but it needs to be there for the example to run. The model is a kind of variational autoencoder trained on MNIST data.
The BaseNet class is trained first. The HighNet receives the encoder and decoder from the trained BaseNet and is trained again together with its own generator network. This HighNet is then saved, and it is the model I am trying to load again. The BaseNet itself is not relevant anymore beyond the parts reused in the HighNet.

Working example for error reproduction:

import tensorflow as tf
import keras
from keras import layers
import numpy as np
import math


def silverman_rule_of_thumb_normal(N):
    return tf.pow((4 / (3 * N)), 0.4)


def pairwise_distances(x, y=None):
    if y is None:
        y = x
    distances_tf = tf.norm(x[:, None] - y, axis=-1) ** 2
    return tf.cast(distances_tf, dtype=tf.float64)


def cw_normality(X, y=None):
    assert len(X.shape) == 2

    D = tf.cast(tf.shape(X)[1], tf.float64)
    N = tf.cast(tf.shape(X)[0], tf.float64)

    if y is None:
        y = silverman_rule_of_thumb_normal(N)

    # adjusts for dimensionality; D=2 -> K1=1, D>2 -> K1<1
    K1 = 1.0 / (2.0 * D - 3.0)

    A1 = pairwise_distances(X)
    A = tf.reduce_mean(1 / tf.math.sqrt(y + K1 * A1))

    B1 = tf.cast(tf.square(tf.math.reduce_euclidean_norm(X, axis=1)), dtype=tf.float64)
    B = 2 * tf.reduce_mean((1 / tf.math.sqrt(y + 0.5 + K1 * B1)))

    return (1 / tf.sqrt(1 + y)) + A - B


def phi_sampling(s, D):
    return tf.pow(1.0 + 4.0 * s / (2.0 * D - 3), -0.5)


def cw_sampling_lcw(first_sample, second_sample, y):
    shape = first_sample.get_shape().as_list()
    dim = np.prod(shape[1:])
    first_sample = tf.reshape(first_sample, [-1, dim])

    shape = second_sample.get_shape().as_list()
    dim = np.prod(shape[1:])
    second_sample = tf.reshape(second_sample, [-1, dim])

    assert len(first_sample.shape) == 2
    assert first_sample.shape == second_sample.shape

    _, D = first_sample.shape

    T = 1.0 / (2.0 * tf.sqrt(math.pi * y))

    A0 = pairwise_distances(first_sample)
    A = tf.reduce_mean(phi_sampling(A0 / (4 * y), D))

    B0 = pairwise_distances(second_sample)
    B = tf.reduce_mean(phi_sampling(B0 / (4 * y), D))

    C0 = pairwise_distances(first_sample, second_sample)
    C = tf.reduce_mean(phi_sampling(C0 / (4 * y), D))

    return T * (A + B - 2 * C)


def euclidean_norm_squared(X, axis=None):
    return tf.reduce_sum(tf.square(X), axis=axis)


def cw_sampling_silverman(first_sample, second_sample):
    stddev = tf.math.reduce_std(second_sample)
    N = tf.cast(tf.shape(second_sample)[0], tf.float64)
    gamma = silverman_rule_of_thumb_normal(N)
    return cw_sampling_lcw(first_sample, second_sample, gamma)


@tf.keras.saving.register_keras_serializable()
class Encoder(keras.Model):
    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)
        self.activation = layers.Activation("relu")
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(256)
        self.dense2 = layers.Dense(args["latent_dim"], name="z")

    def build(self, **kwargs):
        encoder_inputs = keras.Input(shape=(28, 28, 1))
        x = self.flatten(encoder_inputs)
        x = self.dense1(x)
        x = self.activation(x)
        z = self.dense2(x)
        encoder = keras.Model(encoder_inputs, [z], name="encoder")
        return encoder


@tf.keras.saving.register_keras_serializable()
class Decoder(keras.Model):
    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = args["latent_dim"]
        self.activation = layers.Activation("relu")
        self.dense1 = layers.Dense(256)
        self.dense2 = layers.Dense(28 * 28, activation="sigmoid")
        self.reshape = layers.Reshape([28, 28, 1])

    def build(self, **kwargs):
        latent_inputs = keras.Input(shape=(self.latent_dim,))
        x = self.dense1(latent_inputs)
        x = self.activation(x)
        x = self.dense2(x)
        decoder_outputs = self.reshape(x)
        decoder = keras.Model(latent_inputs, decoder_outputs, name="encoder")
        return decoder


@tf.keras.saving.register_keras_serializable()
class Generator(keras.Model):
    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)
        self.noise_dim = args["noise_dim"]
        self.activation = layers.Activation("relu")
        self.dense1 = layers.Dense(512)
        self.dense2 = layers.Dense(args["latent_dim"], name="z")

    def build(self, **kwargs):
        noise_inputs = keras.Input(shape=(self.noise_dim,))
        x = self.dense1(noise_inputs)
        x = self.activation(x)
        z = self.dense2(x)
        latent_generator = keras.Model(noise_inputs, [z], name="generator")
        return latent_generator


@tf.keras.saving.register_keras_serializable()
class BaseNet(keras.Model):
    def __init__(self, args, **kwargs):
        super(BaseNet, self).__init__(**kwargs)
        self.encoder = Encoder(args).build()
        self.decoder = Decoder(args).build()
        self.args = args
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="cw_reconstruction_loss"
        )
        self.cw_loss_tracker = keras.metrics.Mean(name="cw_loss")

    def get_config(self):
        config = {
            "args": self.args
        }
        base_config = super(BaseNet, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.cw_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z = self.encoder(data)
            reconstruction = self.decoder(z)
            # tf.print(reconstruction)
            cw_reconstruction_loss = tf.math.log(
                cw_sampling_silverman(data, reconstruction))
            lambda_val = 1
            cw_loss = lambda_val * tf.math.log(cw_normality(z))
            total_loss = cw_reconstruction_loss + cw_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(cw_reconstruction_loss)
        self.cw_loss_tracker.update_state(cw_loss)
        return {
            "total_loss": self.total_loss_tracker.result(),
            "cw_reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "cw_loss": self.cw_loss_tracker.result(),
        }


@tf.keras.saving.register_keras_serializable()
class HighNet(keras.Model):
    def __init__(self, encoder, decoder, args, **kwargs):
        super(HighNet, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.args = args
        self.generator = Generator(args).build()
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="cw_reconstruction_loss"
        )

    def get_config(self):
        config = {
            "encoder": self.encoder,
            "decoder": self.decoder,
            "args": self.args
        }
        base_config = super(HighNet, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs, **kwargs):
        x = self.encoder(inputs)
        return self.decoder(x)

    @property
    def metrics(self):
        return [
            self.reconstruction_loss_tracker
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z = self.encoder(data)
            batch_size = tf.shape(z)[0]
            noise_np = np.random.normal(0, 1, size=self.args["noise_dim"])
            noise_tf = tf.expand_dims(tf.convert_to_tensor(noise_np), axis=0)
            noise_tf = tf.repeat(noise_tf, repeats=batch_size, axis=0)
            noise_z = self.generator(noise_tf)
            # tf.print(reconstruction)
            cw_reconstruction_loss = tf.math.log(
                cw_sampling_silverman(z, noise_z))
        grads = tape.gradient(cw_reconstruction_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.reconstruction_loss_tracker.update_state(cw_reconstruction_loss)
        return {
            "cw_reconstruction_loss": self.reconstruction_loss_tracker.result()
        }


def test_saving():
    args = {"sample_amount": 1000,
            "latent_dim": 24,
            "noise_dim": 24,
            "epochs": 1,
            "batch_size": 128,
            "patience": 3,
            "learning_rate": 0.0001}
    (x_train, y_train), (x_test, _) = keras.datasets.mnist.load_data()
    mnist_digits = np.concatenate([x_train, x_test], axis=0)[0:100]
    mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
    base_model = BaseNet(args)
    base_model.compile(optimizer=keras.optimizers.Adam(learning_rate=args["learning_rate"]))
    es_callback = keras.callbacks.EarlyStopping(monitor='total_loss', patience=args["patience"], mode="min")
    base_model.fit(mnist_digits, epochs=args["epochs"], batch_size=args["batch_size"], callbacks=[es_callback])

    model = HighNet(base_model.encoder, base_model.decoder, args)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=args["learning_rate"]))
    es2_callback = keras.callbacks.EarlyStopping(monitor='cw_reconstruction_loss', patience=args["patience"],
                                                 mode="min")
    model.fit(mnist_digits, epochs=args["epochs"], batch_size=args["batch_size"], callbacks=[es2_callback])

    model.save("high_model.keras", save_format="keras")

    loaded_model = keras.saving.load_model("high_model.keras")


if __name__ == "__main__":
    test_saving()

Output of the submodels' summary() after saving (encoder, decoder and generator):

Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 256)               200960    
                                                                 
 activation (Activation)     (None, 256)               0         
                                                                 
 z (Dense)                   (None, 24)                6168      
                                                                 
=================================================================
Total params: 207128 (809.09 KB)
Trainable params: 207128 (809.09 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 24)]              0         
                                                                 
 dense_1 (Dense)             (None, 256)               6400      
                                                                 
 activation_1 (Activation)   (None, 256)               0         
                                                                 
 dense_2 (Dense)             (None, 784)               201488    
                                                                 
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
=================================================================
Total params: 207888 (812.06 KB)
Trainable params: 207888 (812.06 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 24)]              0         
                                                                 
 dense_3 (Dense)             (None, 512)               12800     
                                                                 
 activation_2 (Activation)   (None, 512)               0         
                                                                 
 z (Dense)                   (None, 24)                12312     
                                                                 
=================================================================
Total params: 25112 (98.09 KB)
Trainable params: 25112 (98.09 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

I don’t know why this summary looks different from my original summary, but in the original summary I can derive from ‘dense_4/kernel:0’ that the shape mismatch happens in the first dense layer of the encoder; there the layers are annotated with their runtime labels, i.e. dense_4 and so on. I don’t know what I omitted that would lead to the differently styled summary; the original model has some more layers (batch normalization, reused activation functions, etc.), but the error remains the same. In terms of shapes, the assigned value (784, 256) matches the encoder’s first Dense kernel (28*28 flattened input -> 256 units), while the target variable shape (24, 512) matches the generator’s first Dense kernel (noise_dim 24 -> 512 units), which suggests the saved weights are being mapped onto the wrong layers at load time. I can’t possibly imagine why simply saving and loading a model would introduce any shape incompatibilities.
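
To see what actually ends up in the file, it helps to list the stored weight shapes directly. As far as I understand the format, a .keras archive is a zip containing config.json, metadata.json and model.weights.h5. A minimal sketch, assuming that layout and that h5py is available:

import zipfile
import tempfile

import h5py

# Extract the weights file from the .keras archive (a zip) and list every
# stored dataset together with its shape.
with zipfile.ZipFile("high_model.keras") as archive:
    print(archive.namelist())
    weights_path = archive.extract("model.weights.h5", tempfile.mkdtemp())

with h5py.File(weights_path, "r") as h5:
    h5.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))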

I am using TensorFlow 2.15, Keras 2.15 and Python 3.11, working in PyCharm Professional on Ubuntu 23.10.

I already asked this on Stack Overflow, but I thought this place might be more specialized in Keras problems.

While I was trying to debug, I narrowed the error message down to:
ValueError: Layer 'dense_3' expected 0 variables, but received 2 variables during loading. Expected: []
That was weird, because there is no dense_3 layer in the model. I managed to reduce it to the following minimal example:

import numpy as np
import keras
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense


@keras.saving.register_keras_serializable()
class CustomModel1(Model):
    def __init__(self):
        super().__init__()
        self.dense = Dense(32)

    def call(self, inputs):
        x = self.dense(inputs)
        return x


@keras.saving.register_keras_serializable()
class CustomModel2(Model):
    def __init__(self):
        super().__init__()
        self.dense = Dense(32)

    def call(self, inputs):
        x = self.dense(inputs)
        return x


@keras.saving.register_keras_serializable()
class CustomModel3(Model):
    def __init__(self):
        super().__init__()
        self.net1 = CustomModel1()
        self.net2 = CustomModel2()

    def call(self, inputs):
        z = self.net1(inputs)
        x = self.net2(z)
        return z, x

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            # z, y_pred = self(x)                 # this fixes it instead
            y_pred = self.net2(self.net1(x))      # this line throws the error
            loss = self.compiled_loss(y, y_pred)

        trainable_vars = self.trainable_weights
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


# Instantiate the model
model = CustomModel3()

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Create some dummy data for training
x_train = np.random.random((1000, 32))
y_train = np.random.randint(10, size=(1000,))

# Train the model for one epoch
model.fit(x_train, y_train, epochs=1)

# Save the model
model.save('custom_model.keras', save_format='keras')

# Load the model again
loaded_model = tf.keras.models.load_model('custom_model.keras')

# Generate some sample data for prediction
x_sample = np.random.random((10, 32))  # Assuming 10 samples with 32 features each

# Make predictions using the loaded model
predictions = loaded_model.predict(x_sample)
print(predictions)
# Print the summary of the original model
model.summary()

Calling the submodels directly via self.net1 and self.net2 inside train_step triggers the error on loading, whereas routing the forward pass through the higher model's call method (and returning their respective outputs) does not throw an error.
How this could lead to these error messages is beyond me.
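
For completeness, this is the train_step variant that does load again after saving, i.e. a sketch of the workaround from the commented line above (the rest of the minimal example stays unchanged):

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            # Route the forward pass through self(...) so both submodels are
            # exercised via CustomModel3.call; with this variant the saved
            # model loads without the "expected 0 variables" error.
            z, y_pred = self(x)
            loss = self.compiled_loss(y, y_pred)

        trainable_vars = self.trainable_weights
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}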