Custom Quantization on Conv Layers

Hello, I want to apply custom quantization to Conv layers, using per-tensor scales instead of the default per-axis/per-channel scheme. The following script builds a small TFLite (MNIST) model for test purposes in which the Conv layers are quantized with this custom configuration.

from tensorflow_model_optimization.python.core.quantization.keras import quantizers
import tensorflow_model_optimization as tfmot
import tensorflow as tf
import numpy as np
import os

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1))
# https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide



class Default8BitQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
  """QuantizeConfig for non recurrent Keras layers."""

  def __init__(self, weight_attrs, activation_attrs, quantize_output):
    self.weight_attrs = weight_attrs
    self.activation_attrs = activation_attrs
    self.quantize_output = quantize_output

    # TODO(pulkitb): For some layers such as Conv2D, per_axis should be True.
    # Add mapping for which layers support per_axis.
    self.weight_quantizer = quantizers.LastValueQuantizer(
        num_bits=8, per_axis=False, symmetric=True, narrow_range=True)
    self.activation_quantizer = quantizers.MovingAverageQuantizer(
        num_bits=8, per_axis=False, symmetric=False, narrow_range=False)
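    # Note: for the Conv layers below, this weight quantizer is replaced by
    # CustomDefault8BitConvQuantizeConfig, which swaps in
    # Default8BitConvWeightsQuantizer.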

  def get_weights_and_quantizers(self, layer):
    return [(getattr(layer, weight_attr), self.weight_quantizer)
            for weight_attr in self.weight_attrs]

  def get_activations_and_quantizers(self, layer):
    return [(getattr(layer, activation_attr), self.activation_quantizer)
            for activation_attr in self.activation_attrs]

  def set_quantize_weights(self, layer, quantize_weights):
    if len(self.weight_attrs) != len(quantize_weights):
      raise ValueError(
          '`set_quantize_weights` called on layer {} with {} '
          'weight parameters, but layer expects {} values.'.format(
              layer.name, len(quantize_weights), len(self.weight_attrs)))

    for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
      current_weight = getattr(layer, weight_attr)
      if current_weight.shape != weight.shape:
        raise ValueError('Existing layer weight shape {} is incompatible with '
                         'provided weight shape {}'.format(
                             current_weight.shape, weight.shape))

      setattr(layer, weight_attr, weight)

  def set_quantize_activations(self, layer, quantize_activations):
    if len(self.activation_attrs) != len(quantize_activations):
      raise ValueError(
          '`set_quantize_activations` called on layer {} with {} '
          'activation parameters, but layer expects {} values.'.format(
              layer.name, len(quantize_activations),
              len(self.activation_attrs)))

    for activation_attr, activation in zip(self.activation_attrs,
                                           quantize_activations):
      setattr(layer, activation_attr, activation)

  def get_output_quantizers(self, layer):
    if self.quantize_output:
      return [self.activation_quantizer]
    return []

  @classmethod
  def from_config(cls, config):
    """Instantiates a `Default8BitQuantizeConfig` from its config.
    Args:
        config: Output of `get_config()`.
    Returns:
        A `Default8BitQuantizeConfig` instance.
    """
    return cls(**config)

  def get_config(self):
    # TODO(pulkitb): Add weight and activation quantizer to config.
    # Currently it's created internally, but ideally the quantizers should be
    # part of the constructor and passed in from the registry.
    return {
        'weight_attrs': self.weight_attrs,
        'activation_attrs': self.activation_attrs,
        'quantize_output': self.quantize_output
    }

  def __eq__(self, other):
    if not isinstance(other, Default8BitQuantizeConfig):
      return False

    return (self.weight_attrs == other.weight_attrs and
            self.activation_attrs == other.activation_attrs and
            self.weight_quantizer == other.weight_quantizer and
            self.activation_quantizer == other.activation_quantizer and
            self.quantize_output == other.quantize_output)

  def __ne__(self, other):
    return not self.__eq__(other)
  
class Default8BitConvWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantizer for handling weights in Conv2D/DepthwiseConv2D layers."""

  def __init__(self):
    """Construct LastValueQuantizer with params specific for TFLite Convs."""
    # per_axis=False is the custom part: TFMOT's stock version of this
    # quantizer uses per_axis=True, i.e. one scale per output channel.
    super(Default8BitConvWeightsQuantizer, self).__init__(
        num_bits=8, per_axis=False, symmetric=True, narrow_range=True)

  def build(self, tensor_shape, name, layer):
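    # shape=None creates scalar min/max variables, i.e. a single range for
    # the whole kernel rather than one per output channel.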
    min_weight = layer.add_weight(
        name + '_min',
        shape=None,
        initializer=tf.keras.initializers.Constant(-6.0),
        trainable=False)
    max_weight = layer.add_weight(
        name + '_max',
        shape=None,
        initializer=tf.keras.initializers.Constant(6.0),
        trainable=False)

    return {'min_var': min_weight, 'max_var': max_weight}
  
class CustomDefault8BitConvQuantizeConfig(Default8BitQuantizeConfig):
  """QuantizeConfig for Conv2D/DepthwiseConv2D layers."""

  def __init__(self, weight_attrs, activation_attrs, quantize_output):
    super(CustomDefault8BitConvQuantizeConfig,
          self).__init__(weight_attrs, activation_attrs, quantize_output)

    self.weight_quantizer = Default8BitConvWeightsQuantizer()
    

def setup_model():
  quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
  quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
  model = quantize_annotate_model(tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      quantize_annotate_layer(
          tf.keras.layers.Conv2D(64, kernel_size=(3, 3), padding='same',
                                 activation='relu'),
          quantize_config=CustomDefault8BitConvQuantizeConfig(
              ['kernel'], ['activation'], True)),
      quantize_annotate_layer(
          tf.keras.layers.Conv2D(32, kernel_size=(3, 3), padding='same',
                                 activation='relu'),
          quantize_config=CustomDefault8BitConvQuantizeConfig(
              ['kernel'], ['activation'], True)),
      tf.keras.layers.Dropout(0.5),
      quantize_annotate_layer(
          tf.keras.layers.Conv2D(16, kernel_size=(3, 3), padding='same',
                                 activation='relu'),
          quantize_config=CustomDefault8BitConvQuantizeConfig(
              ['kernel'], ['activation'], True)),
      tf.keras.layers.Dropout(0.25),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(10)
  ]))
  quantize_scope = tfmot.quantization.keras.quantize_scope
  with quantize_scope(
    {'CustomDefault8BitConvQuantizeConfig': CustomDefault8BitConvQuantizeConfig}):
    # Use `quantize_apply` to actually make the model quantization aware.
    quant_aware_model = tfmot.quantization.keras.quantize_apply(model)
  return quant_aware_model

def setup_pretrained_weights():
  model = setup_model()

  checkpoint_path = "training_1_custom_conv_q_layer/cp.ckpt"
  checkpoint_dir = os.path.dirname(checkpoint_path)

  # Create a callback that saves the model's weights
  cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   save_weights_only=True,
                                                   verbose=1)
  model.compile(optimizer="adam",
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
  model.fit(x_train, y_train, epochs=1, validation_split=0.1,
            batch_size=1000, callbacks=[cp_callback])

  return model

def train_model():
  # setup_pretrained_weights() builds, compiles and trains the model.
  return setup_pretrained_weights()

def setup_pretrained_model(checkpoint_dir:str):
  model = setup_model()
  latest = tf.train.latest_checkpoint(checkpoint_dir)
  model.load_weights(latest)
  return model

def representative_data_gen():
  # Note: this generator is never assigned to converter.representative_dataset
  # below; a quantization-aware model can be converted without it because the
  # FakeQuant ops already carry the ranges.
  for input_value in tf.data.Dataset.from_tensor_slices(
      x_train.astype(np.float32)).batch(1).take(100):
    # Model has only one input, so each data point has one element.
    yield [input_value]

trained_model = train_model()
converter = tf.lite.TFLiteConverter.from_keras_model(trained_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

with open('conv_layer_quant.tflite', 'wb') as f:
  f.write(quantized_tflite_model)
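
To sanity-check what the converter produced, the per-tensor vs. per-channel layout can be read back from the .tflite file itself. A minimal sketch (not part of the script above; a scales array of length 1 means per-tensor, one entry per output channel means per-channel):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='conv_layer_quant.tflite')
interpreter.allocate_tensors()
for detail in interpreter.get_tensor_details():
  params = detail['quantization_parameters']
  if params['scales'].size > 0:  # skip unquantized tensors
    print(detail['name'], 'num_scales:', params['scales'].size,
          'quantized_dimension:', params['quantized_dimension'])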

The resulting model runs on the edge device's CPU, but inference fails on the same device when I use libvx_delegate.so to run it on the NPU with the following code:

import numpy as np
import tflite_runtime.interpreter as tflite  # assuming the tflite_runtime package

delegate = [tflite.load_delegate("/usr/lib/libvx_delegate.so")]
interpreter = tflite.Interpreter(model_path=modelName, experimental_delegates=delegate)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
interpreter.set_tensor(input_details[0]['index'],
                       np.expand_dims(np.array(image), axis=-1).astype(np.float32))
interpreter.invoke()  # <- throws

which throws the following error:

Vx delegate: allowed_builtin_code set to 0.
Vx delegate: error_during_init set to 0.
Vx delegate: error_during_prepare set to 0.
Vx delegate: error_during_invoke set to 0.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W [HandleLayoutInfer:257]Op 18: default layout inference pass.
E [vsi_nn_QuantCheck:403]input_scale[0.068971894681] * weight_scale[0.000899141247] != bias_scale[0.000041766827]
E [setup_node:481]Check node[4] CONV2D fail
ERROR: Failed to verify graph
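
For reference, this check corresponds to the TFLite quantization convention that a Conv2D bias must be stored with bias_scale = input_scale * weight_scale. Plugging the values from the log into that identity (a quick arithmetic check, nothing more):

input_scale = 0.068971894681
weight_scale = 0.000899141247
bias_scale = 0.000041766827
print(input_scale * weight_scale)  # ~6.2015e-05, which != bias_scale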

What am I doing wrong?
Thanks for the help!