QAT does not restore accuracy of SalsaNext


I am currently working on quantization of SalsaNext, which semantically segments a 2D range image. The network is supposed to run on a platform with 8-bit integer arithmetic, and I am trying to evaluate the expected drop in accuracy (the metric actually used is mIoU / Jaccard index).
I have spent some time training, optimizing, and retraining the model; for example, I swapped the activation and batchnorm layers so that the batchnorm can be folded into the convolutions.
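For reference, folding batchnorm into a preceding convolution follows the standard rescaling identity. Below is a minimal numpy sketch (my own illustration, not the actual validation code; toy shapes, with a 1x1 convolution modelled as a matmul) checking that the folded conv reproduces conv + batchnorm:

```python
import numpy as np

def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-3):
    # Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    # into a single conv with rescaled kernel and shifted bias.
    scale = gamma / np.sqrt(var + eps)          # shape (out_ch,)
    return w * scale, beta + (b - mean) * scale

# Sanity check with a 1x1 convolution modelled as a matmul:
rng = np.random.default_rng(0)
in_ch, out_ch = 4, 6
x = rng.normal(size=(10, in_ch))
w = rng.normal(size=(in_ch, out_ch))
b = rng.normal(size=out_ch)
gamma, beta = rng.normal(size=out_ch), rng.normal(size=out_ch)
mean, var = rng.normal(size=out_ch), rng.uniform(0.5, 2.0, size=out_ch)

y_ref = gamma * ((x @ w + b) - mean) / np.sqrt(var + 1e-3) + beta
w_f, b_f = fold_batchnorm(w, b, gamma, beta, mean, var)
assert np.allclose(x @ w_f + b_f, y_ref)
```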

The literature states that for most models the drop amounts to 1-2% in the chosen metric, so that is what I am hoping for. Unfortunately, the mIoU on the validation dataset drops from 55% to roughly 37%.

I train for 150 epochs to get results close to what the developers of SalsaNext were able to achieve. Afterwards, the batchnorm weights are frozen and folded into the preceding conv layers (this has been validated successfully), and fake quantization is added with the TensorFlow Model Optimization Toolkit for all layers (Conv2D (3x3, 1x1, dilated), LeakyReLU, DepthToSpace, Concat2,3,4, Add, AvgPool2D, and Softmax as the last layer). Then the model is fine-tuned for 10 epochs at 1% of the initial learning rate (with annealing).

Here are the results for some other configurations I have already tried:

  • Default float32: 55%
  • Quantize Conv2D only, 8-bit: 42%
  • Quantize all layers, 8-bit: 37%
  • Quantize all layers, 16-bit: 54%

Is this an expected drop in accuracy or am I missing something?
Any help is appreciated :slight_smile:

Used versions:
python 3.8.10
tensorflow 2.13.0
tensorflow-model-optimization 0.7.5

A drop in mIoU from 55% to 37% after quantizing SalsaNext to 8-bit is larger than usual and could be due to several factors:

  • a suboptimal quantization-aware training (QAT) configuration,
  • an inappropriate quantization granularity or scheme,
  • individual model components that are particularly sensitive to quantization,
  • insufficient calibration data, or
  • implementation issues in the batchnorm folding or the quantization operations themselves.

Fine-tuning the QAT process, adjusting the model architecture, or employing more advanced quantization techniques may help mitigate the accuracy loss.

Thank you for the advice. So far, I have addressed the mentioned issues as follows:

  • Quantization granularity: I’ve tried per-axis and per-tensor quantization.
  • Calibration data: I am training and validating on the full semanticKitti dataset.
  • Quantization scheme: For now, I am using the default scheme with AllValuesQuantizer for activations and MovingAverageQuantizer for weights. I’ve seen in several publications that the chosen calibration method has a major impact on the performance, so I am experimenting with percentile calibration right now.
  • Sensitivity of certain model components: I’ve excluded layers from quantization to investigate that, but even deploying quantized convolutions only did not return the expected results.
  • BatchNorm: I’ve tried both folding the batch normalization weights into the convolution layers and simply unsetting their trainable property.

Is there something I am obviously doing wrong, or anything I should invest more time in?