All zero values after tensor_scatter_nd_update in a Quantum CNN

Hi everyone,

We are trying to build a trainable Quantum Convolutional Neural Network.
To this end, we are creating QuantumConvolutionalLayer, a new subclass of keras.Layer.

Our current approach:

In __init__ we use the add_weight function to register trainable_parameters as a weight. Then, still in __init__, we use PennyLane to create a quantum circuit on 4 qubits that takes 4 classical inputs together with these trainable weights.

In the call function (the forward pass), a 2x2 grid is moved over the entire image (MNIST in our case), and every time the 2x2 grid is processed with the quantum circuit defined in __init__.
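Because the stride equals the window size, the 2x2 windows do not overlap, and the same grids can be cut out of an image with a single reshape instead of index arithmetic. A minimal NumPy sketch, assuming one 28x28 single-channel image and the same element order as the `grid` list in `call` (top-left, top-right, bottom-left, bottom-right); the test image and the checked position are arbitrary:

```python
import numpy as np

# Stand-in 28x28 "image" with distinct pixel values
image = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)
stride = 2

# Split rows and columns into (14 blocks x 2), then bring the two in-block
# axes together: patches[r, c] is the flattened 2x2 block at block position (r, c)
h_out, w_out = image.shape[0] // stride, image.shape[1] // stride
patches = (
    image.reshape(h_out, stride, w_out, stride)
    .transpose(0, 2, 1, 3)
    .reshape(h_out, w_out, stride * stride)
)

# Compare one block with the explicit indexing used in call()
j, k = 4, 10
grid = [image[j, k], image[j, k + 1], image[j + 1, k], image[j + 1, k + 1]]
print(patches.shape)                                            # (14, 14, 4)
print(np.array_equal(patches[j // stride, k // stride], grid))  # True
```

Each of the 14 x 14 rows of `patches` is one set of 4 circuit inputs, and output position `(j // stride, k // stride)` in the loop corresponds to `patches[r, c]` here.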

The result of processing one 2x2 grid with this quantum circuit [measurement_results] is a TensorFlow tensor, and we want to collect all these results in one big tensor [out]. To do this we use tf.tensor_scatter_nd_update. Unfortunately, the resulting out tensor contains only zeros, even though printing the measurement_results of the individual grids shows non-zero values. Any ideas on how this can be solved?

Many thanks in advance for your help!


PS: Below you can find our code so far:

# Library installation
!pip install qiskit
!pip install pennylane 

# General imports
import pennylane as qml
import numpy as np
from numpy import pi, sqrt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
from qiskit import QuantumCircuit, Aer, assemble, visualization

# Dataset
from keras.datasets import mnist

# Embedding imports
from pennylane.templates import QAOAEmbedding

# Build function to load train and test dataset
def load_dataset():
    # load dataset
    (trainX, trainY), (testX, testY) = mnist.load_data()
    # reshape dataset to have a single channel
    trainX = trainX.reshape((trainX.shape[0], 28, 28, 1))
    testX = testX.reshape((testX.shape[0], 28, 28, 1))
    # one hot encode target values
    trainY = tf.keras.utils.to_categorical(trainY)
    testY = tf.keras.utils.to_categorical(testY)
    return trainX, trainY, testX, testY

# Load train and test dataset
X_train, Y_train, X_test, Y_test = load_dataset()
num_images = 10 # set to len(X_train) in order to keep all images (-1 would drop the last one)
X_train = X_train[0:num_images]
X_test = X_test[0:num_images]
Y_train = Y_train[0:num_images]
Y_test = Y_test[0:num_images]

# Build a class for the trainable quantum convolutional layer that is a subclass of the keras.Layer class

class QuantumConvolutionalLayer(Layer):
  def __init__(self, device = "default.qubit", stride = 2, wires = 4, layers = 1, n_measurements = 4):
    # Inherits the initialization of the keras.Layer class
    super(QuantumConvolutionalLayer, self).__init__()

    # Initialize the device (self.wires keeps the number of wires, self.device the PennyLane device)
    self.wires = wires
    self.device = qml.device(device, wires = wires)

    # Initialize the quantum circuit
    self.layers = layers
    self.stride = stride
    self.n_measurements = n_measurements
    # Pass the name as a keyword: newer Keras versions take shape as the first positional argument
    self.trainable_parameters = self.add_weight(name = "trainable_parameters", shape = QAOAEmbedding.shape(n_layers=layers, n_wires=wires), initializer = tf.keras.initializers.RandomNormal())
    # To this end, build the quantum circuit (for 1 square of stride x stride)
    @qml.qnode(self.device, interface = "tf")
    def quantum_circuit(inputs, trainable_parameters = self.trainable_parameters):
      QAOAEmbedding(features = inputs, weights = trainable_parameters, wires = range(wires))
      return [qml.expval(qml.PauliZ(j)) for j in range(n_measurements)]

    #weight_shapes = {"trainable_parameters": QAOAEmbedding.shape(n_layers=self.layers, n_wires=self.wires)}
    #self.quantum_circuit = qml.qnn.KerasLayer(quantum_circuit, weight_shapes = weight_shapes, output_dim = self.n_measurements)
    self.quantum_circuit = quantum_circuit

  def build(self, input_shape):
    # All weights are already created in __init__, so there is nothing left to build
    super(QuantumConvolutionalLayer, self).build(input_shape)

  def call(self, inputs):
    # define forward pass
    h_in, w_in, ch_in = inputs.shape[1:]  # inputs.shape  (28, 28, 1) for MNIST
    h_out, w_out, ch_out = h_in // self.stride, w_in // self.stride, ch_in * self.n_measurements # (14, 14, 4) for MNIST and our quantum circuit filter
    out = tf.zeros((num_images, h_out, w_out, ch_out))
    for img_idx in range(num_images):
      # print(tf.rank(out))
      for j in range(0, h_in, self.stride):
        for k in range(0, w_in, self.stride):
          grid = [inputs[img_idx, j, k, 0], inputs[img_idx, j, k + 1, 0], inputs[img_idx, j + 1, k, 0], inputs[img_idx, j + 1, k + 1, 0]]
          measurement_results = self.quantum_circuit(inputs = grid, trainable_parameters = self.trainable_parameters)
          for ch in range(self.n_measurements):
              tf.tensor_scatter_nd_update(out[img_idx], tf.constant([[j//2, k//2, ch]]), [measurement_results[ch]])
    return out

quanv = QuantumConvolutionalLayer()

Dear all,

We have found the issue.
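tf.tensor_scatter_nd_update does not modify the tensor you pass in; it returns a new tensor in which the entries at the given coordinates are replaced by the new values. So you still have to assign the result yourself. Since out[img_idx] is itself just a slice (a new tensor), the image index also has to become part of the coordinates, so that the update is applied to out as a whole.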

Replacing the following line in the code:

tf.tensor_scatter_nd_update(out[img_idx], tf.constant([[j//2, k//2, ch]]), [measurement_results[ch]])

by this line solves the problem:

out = tf.tensor_scatter_nd_update(out, tf.constant([[img_idx, j//2, k//2, ch]]), [measurement_results[ch]])
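The difference is easy to see on a small tensor. A minimal sketch, assuming TensorFlow 2.x in eager mode; the 2x2x1 shape and the value 0.5 are arbitrary:

```python
import tensorflow as tf

out = tf.zeros((2, 2, 1))
indices = tf.constant([[0, 1, 0]])   # one (row, col, channel) coordinate
updates = tf.constant([0.5])

# Calling the function without keeping its result leaves `out` untouched
tf.tensor_scatter_nd_update(out, indices, updates)
print(float(tf.reduce_sum(out)))     # 0.0

# The updated values only exist in the returned tensor, so rebind `out` to it
out = tf.tensor_scatter_nd_update(out, indices, updates)
print(float(out[0, 1, 0]))           # 0.5
```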

Yes. tf.Tensor objects are immutable.
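In plain Python terms: a tf.Tensor rejects item assignment outright, and the mutable container in TensorFlow is tf.Variable, whose update methods (here `scatter_nd_update`, assuming TF 2.x) change it in place. A small illustration only; inside a layer's `call`, computing and returning fresh tensors is usually the better fit.

```python
import tensorflow as tf

t = tf.zeros((2, 2))
caught = False
try:
    t[0, 0] = 1.0                    # tf.Tensor does not support item assignment
except TypeError as err:
    caught = True
    print("TypeError:", err)

# tf.Variable is mutable: this updates v itself, no reassignment needed
v = tf.Variable(tf.zeros((2, 2)))
v.scatter_nd_update([[0, 0]], [1.0])
print(v.numpy())
```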

    for img_idx in range(num_images):
      # print(tf.rank(out))
      for j in range(0, h_in, self.stride):
        for k in range(0, w_in, self.stride):
            for ch in range(...):
                out = tf.tensor_scatter_nd_update(...)

Ouch, a quadruple loop.

Remember TensorFlow is way more efficient if you can express things in vectorized terms.

If there’s any way to make self.quantum_circuit run on a batch of items, you should consider using tf.image.extract_patches and then running it once over the batch of patches. tf.map_fn or tf.vectorized_map may do it.
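For the 2x2, stride-2 case, tf.image.extract_patches produces every grid of the whole batch in one call. A minimal sketch, assuming TensorFlow 2.x and MNIST-shaped input `(batch, 28, 28, 1)` (random values here); each patch is flattened row by row, which matches the order of the `grid` list in `call`:

```python
import numpy as np
import tensorflow as tf

# A small batch of MNIST-shaped images: (batch, 28, 28, 1)
images = tf.random.uniform((3, 28, 28, 1))

# Non-overlapping 2x2 windows with stride 2 -> shape (3, 14, 14, 4);
# each patch is [x[j, k], x[j, k + 1], x[j + 1, k], x[j + 1, k + 1]]
patches = tf.image.extract_patches(
    images,
    sizes=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)

# One row of 4 circuit inputs per evaluation: (3 * 14 * 14, 4)
flat_patches = tf.reshape(patches, (-1, 4))

# Check one patch against the explicit indexing used in call()
img_idx, j, k = 1, 6, 20
x = images[img_idx, :, :, 0]
grid = tf.stack([x[j, k], x[j, k + 1], x[j + 1, k], x[j + 1, k + 1]])
print(tuple(patches.shape))                                # (3, 14, 14, 4)
print(np.allclose(patches[img_idx, j // 2, k // 2], grid))  # True
```

`flat_patches` is the batch of patches the circuit would then be run over, for example with tf.map_fn.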

Similarly, out = tf.tensor_scatter_nd_update(out, ...) makes a complete copy of the entire out tensor for every channel of every output pixel. Ooof. tf.map_fn or tf.vectorized_map would address this too, as they collect all the results and then stack them.
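Putting the pieces together: extract the patches, run the per-patch function with tf.map_fn, and reshape the stacked results into the output feature map. A minimal sketch, assuming TensorFlow 2.x; `stand_in_circuit` is a hypothetical placeholder for `self.quantum_circuit` (4 inputs to 4 outputs), not a PennyLane call. The scatter-update loop from the thread is kept only to check that both give the same values:

```python
import tensorflow as tf

def stand_in_circuit(patch):
    # Hypothetical placeholder for self.quantum_circuit: 4 inputs -> 4 outputs
    return tf.cos(patch) * tf.sin(patch[::-1])

images = tf.random.uniform((1, 28, 28, 1))
patches = tf.image.extract_patches(
    images, sizes=[1, 2, 2, 1], strides=[1, 2, 2, 1],
    rates=[1, 1, 1, 1], padding="VALID",
)
flat_patches = tf.reshape(patches, (-1, 4))            # (196, 4)

# map_fn calls the function once per row and stacks the results at the end,
# so the output tensor is built once instead of being copied on every write
results = tf.map_fn(stand_in_circuit, flat_patches)    # (196, 4)
out = tf.reshape(results, (1, 14, 14, 4))

# Reference: the per-value scatter loop from the thread (slow, copies each time)
reference = tf.zeros((1, 14, 14, 4))
x = images[0, :, :, 0]
for j in range(0, 28, 2):
    for k in range(0, 28, 2):
        grid = tf.stack([x[j, k], x[j, k + 1], x[j + 1, k], x[j + 1, k + 1]])
        res = stand_in_circuit(grid)
        for ch in range(4):
            reference = tf.tensor_scatter_nd_update(
                reference, [[0, j // 2, k // 2, ch]], [res[ch]])

print(float(tf.reduce_max(tf.abs(out - reference))))   # 0.0 (or float noise)
```

tf.vectorized_map takes the same `(fn, elems)` arguments and can be faster when the function is built from vectorizable TensorFlow ops; whether a given PennyLane QNode supports it is not established by this sketch.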


Thanks a lot markdaoust!

Thanks to your comment we managed to streamline our code a lot in this vectorized way.
I’m sorry for the late reply, but I still wanted to let you know that your comment helped us out quite a lot.

Kind regards