Extreme model accuracy loss due to TFLite conversion w/ quantization

Hello everyone,

I am currently working on training an object detection model using an SSD MobileNet V2 configuration in TensorFlow 2.5.
This generally works fine, with training finishing at around ~0.1 loss. Loading the model gives good detections that I can work with so far.

The problem is:
My current test cases all run on single images. In the final application this model is supposed to do the object detection live on incoming camera images, so the plan was to use a TPU and convert the trained model to TFLite so the Coral edgetpu compiler can handle it.
For this to work the TFLite model needs to be 8-bit integer quantized, but the model becomes unusable after being converted to TFLite with 8-bit quantization. If my understanding of my model configuration is correct, I even used quantization aware training to reduce the accuracy loss, but it still just results in garbage detections.

I use the following configuration for the SSD MobileNet V2:

model {
  ssd {
    inplace_batchnorm_update: true
    freeze_batchnorm: false
    num_classes: 3
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
        use_matmul_gather: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    encode_background_as_zeros: true
    anchor_generator {
      ssd_anchor_generator {
        num_layers: 6
        min_scale: 0.2
        max_scale: 0.95
        aspect_ratios: 1.0
        aspect_ratios: 2.0
        aspect_ratios: 0.5
        aspect_ratios: 3.0
        aspect_ratios: 0.3333
      }
    }
    image_resizer {
      fixed_shape_resizer {
        height: 514
        width: 614
      }
    }
    box_predictor {
      convolutional_box_predictor {
        min_depth: 0
        max_depth: 0
        num_layers_before_predictor: 0
        use_dropout: false
        dropout_keep_probability: 0.8
        kernel_size: 1
        box_code_size: 4
        apply_sigmoid_to_scores: false
        class_prediction_bias_init: -4.6
        conv_hyperparams {
          activation: RELU_6,
          regularizer {
            l2_regularizer {
              weight: 0.00004
            }
          }
          initializer {
            random_normal_initializer {
              stddev: 0.01
              mean: 0.0
            }
          }
          batch_norm {
            train: true,
            scale: true,
            center: true,
            decay: 0.97,
            epsilon: 0.001,
          }
        }
      }
    }
    feature_extractor {
      type: 'ssd_mobilenet_v2_keras'
      min_depth: 16
      depth_multiplier: 1.0
      conv_hyperparams {
        activation: RELU_6,
        regularizer {
          l2_regularizer {
            weight: 0.00004
          }
        }
        initializer {
          truncated_normal_initializer {
            stddev: 0.03
            mean: 0.0
          }
        }
        batch_norm {
          train: true,
          scale: true,
          center: true,
          decay: 0.97,
          epsilon: 0.001,
        }
      }
      override_base_feature_extractor_hyperparams: true
    }
    loss {
      classification_loss {
        weighted_sigmoid_focal {
          alpha: 0.75,
          gamma: 2.0
        }
      }
      localization_loss {
        weighted_smooth_l1 {
          delta: 1.0
        }
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    normalize_loss_by_num_matches: true
    normalize_loc_loss_by_codesize: true
    post_processing {
      batch_non_max_suppression {
        score_threshold: 1e-8
        iou_threshold: 0.6
        max_detections_per_class: 5
        max_total_detections: 15
      }
      score_converter: SIGMOID
    }
  }
}

train_config: {
  batch_size: 8
  sync_replicas: true
  startup_delay_steps: 0
  replicas_to_aggregate: 8
  num_steps: 40000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    random_crop_image {
      min_object_covered: 0.0
      min_aspect_ratio: 0.75
      max_aspect_ratio: 3.0
      min_area: 0.75
      max_area: 1.0
      overlap_thresh: 0.0
    }
  }
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        cosine_decay_learning_rate {
          learning_rate_base: .13
          total_steps: 40000
          warmup_learning_rate: .026666
          warmup_steps: 1000
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  max_number_of_boxes: 100
  unpad_groundtruth_tensors: false
}

train_input_reader: {
  label_map_path: "...\\label_map.pbtxt"
  tf_record_input_reader {
    input_path: "...\\training.tfrecord"
  }
}

eval_config: {
  metrics_set: "coco_detection_metrics"
  use_moving_averages: false
}

eval_input_reader: {
  label_map_path: "...\\label_map.pbtxt"
  shuffle: false
  num_epochs: 1
  tf_record_input_reader {
    input_path: "...\\eval.tfrecord"
  }
}
graph_rewriter {
  quantization {
    delay: 1000
    weight_bits: 8
    activation_bits: 8
  }
}

I’d like to show some images in here but the forum won’t let me. I will try to figure out a way to post some example images…
As a description of what you’d see: the TFLite model creates about 5 detections with ~50% confidence and horrible bounding boxes (2-3x too large, with the center of the object somewhere on the border line), sometimes even with the wrong label. Meanwhile the original model is always spot on, with at least 99% confidence and perfectly fitting bounding boxes (at most 2-3 pixels off).
I am fully aware that a quantized model will never reach the same accuracy as a float model, but the TensorFlow documentation led me to believe that the accuracy loss should be somewhere around <3%.
While I do not know exactly how the term “accuracy loss” is defined here, I would not describe my results as a <3% accuracy loss.
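
To make that comparison a bit more concrete, this is roughly how I look at the drop on a single test image. This is only a sketch: FLOAT_SAVED_MODEL_PATH and TFLITE_MODEL_PATH are placeholders (the float model being a normal, non-TFLite SavedModel export of the same checkpoint, the TFLite model the converted one from the script further below), and load_image_into_numpy_array is my own helper that returns a (1, H, W, 3) uint8 array.

import numpy as np
import tensorflow as tf

image_np = tf_training_helper.load_image_into_numpy_array(IMAGE_PATHS)

# Float baseline: standard Object Detection API SavedModel inference.
detect_fn = tf.saved_model.load(FLOAT_SAVED_MODEL_PATH)
detections = detect_fn(tf.convert_to_tensor(image_np))
float_scores = detections['detection_scores'][0].numpy()
float_boxes = detections['detection_boxes'][0].numpy()

# Quantized model: TFLite interpreter with uint8 input.
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], image_np.astype(np.uint8))
interpreter.invoke()
tflite_boxes = interpreter.get_tensor(output_details[0]['index'])[0]
tflite_scores = interpreter.get_tensor(output_details[2]['index'])[0]

# Compare the first (highest-scoring) detection of each model.
print("float  top score:", float_scores[0], "box:", float_boxes[0])
print("tflite top score:", tflite_scores[0], "box:", tflite_boxes[0])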

For completeness, my workflow: I convert the model from checkpoints to SavedModel format using the script export_tflite_graph_tf2.py from the Object Detection API. After this I use the following Python script to load the model, convert it, and save it:

import numpy as np
import tensorflow as tf
# (path constants and the import of my tf_training_helper module are redacted)

def representative_dataset_gen():
    # Yield calibration samples, one image per call, as float32 arrays of shape (1, H, W, C).
    data = tf_training_helper.load_images_in_folder_to_numpy_array(IMAGE_BASE_PATH)
    (count, x, y, c) = data.shape
    for i in range(count):
        yield [data[i, :, :, :].reshape(1, x, y, c).astype(np.float32)]

input_data = tf_training_helper.load_image_into_numpy_array(IMAGE_PATHS)

# Full-integer quantization with uint8 input/output (as required by the Edge TPU compiler).
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_PATH)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.target_spec.supported_types = [tf.int8]

tflite_model = converter.convert()
print("model conversion finished. Starting validation... ")

# Quick validation run on a single image with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(input_details[0]['index'], input_data.astype(np.uint8))

interpreter.invoke()
out_boxes = interpreter.get_tensor(output_details[0]['index'])
out_classes = interpreter.get_tensor(output_details[1]['index'])
out_scores = interpreter.get_tensor(output_details[2]['index'])

# Save the quantized model to disk.
with tf.io.gfile.GFile(MODEL_SAVE_PATH, 'wb') as f:
    f.write(tflite_model)

(I redacted some unnecessary parts, e.g. the path definitions, to keep it a bit shorter.)
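
For what it’s worth, a quick way to sanity check the converted model is to print the input/output tensor details and confirm that they really are uint8 (a small sketch that can be appended to the script above):

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
for detail in interpreter.get_input_details() + interpreter.get_output_details():
    # 'quantization' is a (scale, zero_point) tuple for quantized tensors.
    print(detail['name'], detail['dtype'], detail['quantization'])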

If you have any idea how this loss in accuracy can be prevented, or even just why it is there, please let me know. Any ideas or suggestions are welcome (e.g. is SSD MobileNet the right model to go for? Are other models more robust to conversion?).
Please also let me know if I missed any information; I’d be happy to provide it.

Have a great weekend,
Cheers
Georg

Hello,

  1. How exactly did you perform QAT? Did you fine-tune the model or train it from scratch? Fine-tuning is better. Are you quantizing all layers in the model? There might be operations in the first layers that, once quantized, make the model’s generalization performance decrease.

  2. I think once you have done QAT, you don’t need a representative dataset to calibrate the range of the activations; just run the converter with converter.optimizations = [tf.lite.Optimize.DEFAULT], like it is done in the docs: Quantization aware training comprehensive guide | TensorFlow Model Optimization.
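
A rough sketch of what I mean, adapted from that guide to your SavedModel workflow (MODEL_PATH as in your script):

import tensorflow as tf

# Post-QAT conversion as in the comprehensive guide:
# only the default optimizations, no representative dataset.
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_PATH)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()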

You can check the model’s layer details in the Netron app; I use it to debug and see if the layer inputs/outputs and weights make sense after quantization.

Hey,

Thank you for your thoughts on this!

To 1: As far as I am aware, QAT (assuming this means quantization aware training) is done solely by adding the following to the training config at the bottom:

graph_rewriter {
  quantization {
    delay: 1000
    weight_bits: 8
    activation_bits: 8
  }
}

This should activate the quantization effects after the first 1000 steps of the training run. I do not know if or how I could tell whether all layers are quantized by this, as I was not able to find much information on it (if you can point me to a good read on this topic, I’d gladly take it). I train the model from scratch, as my application is really narrow and other detections are not needed/unwanted. From what I can tell, fine-tuning should have nothing to do with this, no? In the end the model itself works perfectly before the TFLite conversion.

To 2: Yes, that was my thought as well, but as far as I know the full quantization needed for the Coral edgetpu compiler is only achieved with this configuration:

converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.target_spec.supported_types = [tf.int8]

This however fails with an error demanding a representative dataset if none is provided.
Following your advice, I ran a conversion test using only

converter.optimizations = [tf.lite.Optimize.DEFAULT]

instead. Sadly this leads to even worse results, with the highest confidence being 21%, with the wrong label and a bounding box that is not even close to the object.

Thank you for pointing me towards Netron. The net setup seems valid to me, and the in-between tensor sizes could be valid as well. For some reason I am not able to see the weights in the TFLite model. The only thing I can confirm from this is that the quantized TFLite model from my original post is indeed quantized, as all factors in all layers seem to be integers and the input and output each run through a (de)quantization block.
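
In case it helps others, a quick way to list the tensor types programmatically (instead of clicking through Netron) would be something like this sketch (the model path is a placeholder):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="detect_quant.tflite")  # placeholder path
interpreter.allocate_tensors()
for tensor in interpreter.get_tensor_details():
    # dtype shows which tensors ended up as int8/uint8 and which stayed float32.
    print(tensor['name'], tensor['dtype'], tensor['quantization'])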

Do you think that the workflow I use is in general not suited to achieve the desired result? I read in some GitHub issues that this should work in general.

Cheers

Apologies for bumping an old thread, but I wanted to ask: with recent improvements in the model, has the accuracy improved, and can it be reliably used on an edge microcontroller or even a small Linux processor?