Transfer learning and quantization-aware training with a subclassed model

I am getting the following error when trying to do quantization-aware training with TensorFlow 2.7:

ValueError: `to_quantize` can only either be a tf.keras Sequential or Functional model.

The error occurs when calling this method:

quantize_model = tfmot.quantization.keras.quantize_model(model)

The model is defined below. I suppose the reason is that subclassed models are not supported? I have already trained (normal training, not QAT) multiple models with this definition. Post-training quantization works, but I would like to try quantization-aware training to see if it improves performance. Is there a way to do quantization-aware training with the model below, or do I have to define it in another way and redo normal training?

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, Activation, Dropout
from tensorflow.keras.models import Model

class Mobilenet_v2_transfer(Model):

    def __init__(self):
        super(Mobilenet_v2_transfer, self).__init__()
        self.base = tf.keras.applications.mobilenet_v2.MobileNetV2(
            input_shape=(224, 224, 3), alpha=1.0, include_top=False, weights='imagenet')
        self.base.trainable = True
        # Freeze the first 130 layers of the pretrained backbone
        for layer in self.base.layers[:130]:
            layer.trainable = False
        self.flatten = Flatten()
        self.dense = Dense(1, kernel_regularizer=tf.keras.regularizers.L2(0.01))
        self.sigmoid = Activation('sigmoid')
    def call(self, x):
        x = self.base(x)
        x = self.flatten(x)
        x = self.dense(x)
        x = self.sigmoid(x)
        return x

Have you checked:

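Basically, the current QAT API doesn't support subclassed models through quantize_model alone: it only accepts Sequential or Functional models, which is the check that raises the error above.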

I'd recommend changing the model to a functional model if it is simple enough. (This may require redoing normal training to make sure the model is exactly the same.)

Functional model example: models/ at f8f4845cc85ef674d6285337b55e43638039ff91 · tensorflow/models (GitHub)

`super().__init__(inputs=inputs, outputs=x)` is a pattern that lets you create a subclassed functional model, which is internally the same as the example here: The Functional API | TensorFlow Core

However, it also doesn't support quantizing nested models recursively, as @Bhack mentioned.

So you may have to add a flag to `__init__` that decides whether to quantize the `self.base` model manually (by calling `quantize_apply`) inside `__init__`. You also have to call `quantize_apply` on the entire model to quantize the layers outside of `self.base`. This is similar to the link @Bhack mentioned.