TensorFlow - Save & load a custom model in Keras v3 format

I have problems saving and loading models composed of custom layers.

A good example of this kind of model is the VAE example in the official TensorFlow documentation: https://www.tensorflow.org/guide/keras/making_new_layers_and_models_via_subclassing#putting_it_all_together_an_end-to-end_example

Saving the model raises no error:

vae.save("vae211.keras")

But when I load it:

vae_ = tf.keras.models.load_model("vae211.keras")

I get the following error:

TypeError: Could not locate class 'VariationalAutoEncoder'. Make sure
custom classes are decorated with
@keras.saving.register_keras_serializable(). Full object config:
{'module': None, 'class_name': 'VariationalAutoEncoder', 'config':
{'name': 'autoencoder', 'trainable': True, 'dtype': 'float32',
'img_d': 784, 'hidden_d': 128, 'latent_d': 32}, 'registered_name':
'Custom>VariationalAutoEncoder', 'build_config': {'input_shape': [64,
784]}}

I tried adding `get_config()` methods, but that did not help.

The full code is below; any help is appreciated:

import tensorflow as tf

# Model / training hyper-parameters.
img_d = 784  # flattened 28x28 MNIST image
hidden_d = 128
latent_d = 32
epochs = 2

"""
        Dataset
"""


(x_tra, _), _ = tf.keras.datasets.mnist.load_data()
# Flatten every image to a vector of length img_d and scale pixels to [0, 1].
# Using -1 lets the sample count follow the data instead of hard-coding 60000,
# and img_d keeps the feature size in sync with the model's output size.
x_tra = x_tra.reshape(-1, img_d).astype("float32") / 255

tra_ds = tf.data.Dataset.from_tensor_slices(x_tra)
tra_ds = tra_ds.shuffle(buffer_size=1024).batch(64)

"""
        Model
"""

@tf.keras.saving.register_keras_serializable()
class Sampling(tf.keras.layers.Layer):
    """Sample ``z`` with the reparameterization trick.

    Returns ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
    standard-normal noise with the same shape as ``z_mean``, so gradients
    flow through ``z_mean`` and ``z_log_var``.
    """

    def call(self, z_mean, z_log_var):
        # Use the whole dynamic shape instead of unpacking tf.shape(z_mean)
        # into Python variables: iterating over a tf.Tensor is only allowed
        # eagerly and raises in graph mode (tf.function, or when Keras traces
        # the model symbolically, e.g. while rebuilding it on load). Passing
        # the shape tensor directly is graph-safe and works for any rank.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


@tf.keras.saving.register_keras_serializable()
class Encoder(tf.keras.layers.Layer):
    """Encode inputs into ``(z_mean, z_log_var, z)``.

    Args:
        latent_d: Size of the latent vector (units of the mean/log-var heads).
        hidden_d: Units of the hidden ReLU Dense layer.
        name: Layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, latent_d=32, hidden_d=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.latent_d = latent_d
        self.hidden_d = hidden_d
        self.dense1 = tf.keras.layers.Dense(hidden_d, activation="relu")
        self.dense_mean = tf.keras.layers.Dense(latent_d)
        self.dense_log_var = tf.keras.layers.Dense(latent_d)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense1(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling(z_mean, z_log_var)
        return z_mean, z_log_var, z

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"latent_d": self.latent_d, "hidden_d": self.hidden_d})
        return config


@tf.keras.saving.register_keras_serializable()
class Decoder(tf.keras.layers.Layer):
    """Decode a latent vector back to a flattened image with values in (0, 1).

    Args:
        img_d: Size of the reconstructed output vector.
        hidden_d: Units of the hidden ReLU Dense layer.
        name: Layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, img_d, hidden_d=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        # Kept so get_config() can reproduce the constructor call.
        self.img_d = img_d
        self.hidden_d = hidden_d
        self.dense1 = tf.keras.layers.Dense(hidden_d, activation="relu")
        self.dense2 = tf.keras.layers.Dense(img_d, activation="sigmoid")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"img_d": self.img_d, "hidden_d": self.hidden_d})
        return config


@tf.keras.saving.register_keras_serializable()
class VariationalAutoEncoder(tf.keras.Model):
    """Encoder + decoder VAE whose forward pass also registers the KL loss.

    Args:
        img_d: Size of the flattened input / reconstructed output.
        hidden_d: Units of the hidden layers in encoder and decoder.
        latent_d: Size of the latent vector.
        name: Model name.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, img_d, hidden_d=64, latent_d=32,
                 name="autoencoder", **kwargs):
        super().__init__(name=name, **kwargs)
        # Keep every constructor argument so get_config() can rebuild the
        # model when it is reloaded from a .keras file.
        self.img_d = img_d
        self.hidden_d = hidden_d
        self.latent_d = latent_d
        self.encoder = Encoder(latent_d=latent_d, hidden_d=hidden_d)
        self.decoder = Decoder(img_d=img_d, hidden_d=hidden_d)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1),
        # averaged over all elements; collected later through `self.losses`.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed

    def get_config(self):
        """Return the constructor arguments needed to re-create this model."""
        config = super().get_config()
        # `name` is written explicitly: some Keras 2.x versions return an
        # empty dict from Model.get_config() when a subclass overrides it.
        config.update({
            "name": self.name,
            "img_d": self.img_d,
            "hidden_d": self.hidden_d,
            "latent_d": self.latent_d,
        })
        return config


vae = VariationalAutoEncoder(
    img_d=img_d, hidden_d=hidden_d, latent_d=latent_d
)

optimizer = tf.keras.optimizers.Adam(1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

"""
        Fitting
"""
# Running mean of the training loss. It is reset at the start of every epoch
# so the printed "mean loss" covers the current epoch only, not all epochs so far.
loss_metric = tf.keras.metrics.Mean()

for epoch in range(epochs):
    print(f"Start of {epoch = }")
    loss_metric.reset_state()

    for step, x in enumerate(tra_ds):
        with tf.GradientTape() as tape:
            reconstructed = vae(x)
            # Reconstruction loss plus the KL term registered with add_loss()
            # during the forward pass above.
            loss = mse_loss_fn(x, reconstructed)
            loss += sum(vae.losses)

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print(f"{step=}: mean loss = {loss_metric.result():.4f}")

"""
        Save & Load
"""

vae.save("vae211.keras")
vae_ = tf.keras.models.load_model("vae211.keras") # <===== Error

decoder_ = vae.decoder
...