Training Variational Autoencoder (VAE) from multiple input tensors

Hi team,

I am new to VAE implementation. I want to predict an SINR image (RGB) from multiple input tensors such as a Euclidean distance image, a 3D distance image, a permittivity image, and a conductivity image (all in RGB).

Please let me know if there is any sample work on this topic, or any head start anyone can give me. I could not find any relevant work online that could help me with the architecture.

Looking forward to any kind of suggestion.

Thank You,

I have been looking at the Keras Functional API. I am attaching a relevant link below for future reference. It supports concatenating multiple inputs to predict multiple outputs, which is what I am looking for.

Link-1: The Functional API  |  TensorFlow Core
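As a minimal sketch of the multi-input pattern the Functional API enables (layer sizes and input shapes here are illustrative assumptions, not taken from this thread):

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# Two image inputs (shapes are placeholders)
input_a = layers.Input(shape=(120, 160, 1), name="input_a")
input_b = layers.Input(shape=(120, 160, 1), name="input_b")

# Independent feature extractors for each input
feat_a = layers.Conv2D(16, 3, activation="relu", padding="same")(input_a)
feat_b = layers.Conv2D(16, 3, activation="relu", padding="same")(input_b)

# Fuse the feature maps along the channel axis
merged = layers.Concatenate(axis=-1)([feat_a, feat_b])
output = layers.Conv2D(1, 3, activation="sigmoid", padding="same")(merged)

model = Model(inputs=[input_a, input_b], outputs=output)
```

The same pattern scales to any number of input branches; each branch can have its own depth before the concatenation.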


To predict an SINR image from multiple input tensors like Euclidean distance, 3D distance, permittivity, and conductivity images using a Variational Autoencoder (VAE), you should create a multi-input VAE architecture. This involves separate input layers for each type of image, feature extraction layers for each input, a fusion layer to concatenate all features, and a typical VAE encoder-decoder structure with a latent space. The model would use convolutional layers for feature extraction and transposed convolutional layers in the decoder to reconstruct the target SINR image. Training involves optimizing both reconstruction loss and the KL divergence to ensure effective learning and generation. This approach requires experimentation with the model’s architecture, hyperparameters, and training procedures to achieve the desired results.
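A hedged sketch of how the pieces described above might fit together. The latent size, layer widths, and input shapes are assumptions for illustration, not values from the original post:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

latent_dim = 8  # assumed latent size

# --- Encoder: one input branch per image type (shapes are placeholders) ---
in_perm = layers.Input(shape=(120, 160, 1), name="permittivity")
in_dist = layers.Input(shape=(120, 160, 1), name="distance")
x = layers.Concatenate(axis=-1)([in_perm, in_dist])
x = layers.Conv2D(32, 3, strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(64, 3, strides=2, activation="relu", padding="same")(x)
x = layers.Flatten()(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)

# Reparameterization trick: z = mean + sigma * epsilon
def sample_z(args):
    mean, log_var = args
    eps = tf.random.normal(tf.shape(mean))
    return mean + tf.exp(0.5 * log_var) * eps

z = layers.Lambda(sample_z, name="z")([z_mean, z_log_var])
encoder = Model([in_perm, in_dist], [z_mean, z_log_var, z], name="encoder")

# --- Decoder: transposed convolutions back up to the target image size ---
latent_in = layers.Input(shape=(latent_dim,))
y = layers.Dense(30 * 40 * 64, activation="relu")(latent_in)
y = layers.Reshape((30, 40, 64))(y)
y = layers.Conv2DTranspose(64, 3, strides=2, activation="relu", padding="same")(y)
y = layers.Conv2DTranspose(32, 3, strides=2, activation="relu", padding="same")(y)
out = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(y)
decoder = Model(latent_in, out, name="decoder")
```

The two strided convolutions downsample 120x160 to 30x40, and the two transposed convolutions upsample back, so the decoder output matches the target image shape.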

Thank you @Tim_Wolfe for the suggestion. I was able to successfully formulate my encoder and decoder with multiple inputs.

To compute the overall loss from the two input tensors, I’ve implemented the following steps:

# Calculate reconstruction loss for permittivity_input
reconstruction_loss_perm = tf.reduce_mean(tf.reduce_sum(tf.keras.losses.binary_crossentropy(permittivity_input, vae_output), axis=(1, 2)))

# Calculate reconstruction loss for APloc_input
reconstruction_loss_APloc = tf.reduce_mean(tf.reduce_sum(tf.keras.losses.binary_crossentropy(APloc_input, vae_output), axis=(1, 2)))

# Compute KL divergence loss
kl_loss = -0.5 * tf.reduce_mean(tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1))

# Total VAE loss
vae_loss = reconstruction_loss_perm + reconstruction_loss_APloc + kl_loss

# Add the total VAE loss to the model
vae.add_loss(vae_loss)

However, I've encountered the following error: TypeError: unhashable type: 'DictWrapper' at the line where I add the loss to the VAE model (vae.add_loss(vae_loss)).

I would greatly appreciate any suggestions or insights on resolving this issue.

Thank you,

[Best Guess]

To resolve the TypeError: unhashable type: 'DictWrapper' when adding the VAE loss, ensure all inputs and operations are compatible with TensorFlow expectations, inspect vae_loss and its components for type issues, simplify the loss addition to test functionality, and confirm you are using an up-to-date TensorFlow version. If the issue persists, check for custom layers or operations in your model that might be causing the problem.
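One common way to sidestep add_loss issues entirely is to compute the loss inside a custom train_step of a subclassed Model. This is a sketch under the assumption that an encoder returning (z_mean, z_log_var, z) and a decoder are already built as Keras models; it is not the original poster's implementation:

```python
import tensorflow as tf

class VAE(tf.keras.Model):
    """Wraps an existing encoder/decoder pair and computes the VAE loss itself."""

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        inputs, targets = data  # targets: the image the decoder should reproduce
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(inputs)
            reconstruction = self.decoder(z)
            # Per-image reconstruction loss, summed over pixels, averaged over batch
            recon_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(targets, reconstruction),
                    axis=(1, 2),
                )
            )
            # KL divergence between q(z|x) and the unit Gaussian prior
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1
                )
            )
            total_loss = recon_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": total_loss, "recon": recon_loss, "kl": kl_loss}
```

Because the loss is computed explicitly from tensors inside train_step, nothing is registered through add_loss, which avoids the DictWrapper tracking machinery altogether.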

I’m sharing this update for the benefit of others who may encounter a similar issue in their work.

After some troubleshooting, I managed to resolve the issue. It turned out that the reconstruction loss was functioning correctly, but adjustments were needed during the training phase. Specifically, I had to reshape the training and testing datasets as demonstrated below:

X_train_reshaped = X_train.reshape(-1, 120, 160, 1)
X_test_reshaped = X_test.reshape(-1, 120, 160, 1)

batch_size = 120
epochs = 10

# In an autoencoder we fit the training data to itself.
# Train the model
history = vae.fit(
    [X_train_reshaped, X_train_reshaped],  # Input data (one copy per input branch)
    epochs=epochs,                          # Number of epochs
    batch_size=batch_size,                  # Batch size
    validation_data=([X_test_reshaped, X_test_reshaped], None)  # Validation data
)

I hope this clarification helps others facing a similar challenge.

Thank you,

I’d like to extend my gratitude to @Tim_Wolfe for his valuable suggestions. They proved to be incredibly helpful in addressing the issue at hand.