TensorFlow Lite Metadata

Hello there,

I have a segmentation model with a classification branch, so the final model has two outputs. I am trying to convert the model to TFLite and add metadata to it, but I am getting this error:

ValueError: The number of output tensors (2) should match the number of output tensor metadata (1)

This is how I am adding the metadata

    def add_metadata(self, label_file_paths=[]):
        """
        adds metadata to the model to be used for inference on android and writes it to a file
        """
        print("INFO: Adding metadata")
        from tflite_support.metadata_writers import writer_utils, image_segmenter
        model_path = self.tflite_quant_model_name

        writer = image_segmenter.MetadataWriter.create_for_inference(
            writer_utils.load_file(model_path),
            [self.NORM_MEAN],
            [self.NORM_STD],
            label_file_paths=label_file_paths
        )
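The stock image_segmenter writer describes a single output tensor (the segmentation masks), while this model has two outputs, so the metadata tooling's consistency check fails: it expects one tensor metadata entry per output tensor of the subgraph. The sketch below is a simplified, hypothetical re-creation of that rule (check_output_metadata is not a tflite_support function). It shows that a one-entry list reproduces the message above and a two-entry list passes:

```python
from typing import List, Sequence


def check_output_metadata(output_tensor_names: Sequence[str],
                          output_metadata: List[dict]) -> None:
    """Hypothetical mirror of the rule: one metadata entry per output tensor."""
    if len(output_tensor_names) != len(output_metadata):
        raise ValueError(
            "The number of output tensors ({}) should match the number of "
            "output tensor metadata ({})".format(
                len(output_tensor_names), len(output_metadata)))


outputs = ["segmentation_masks", "probability"]  # the model's two outputs

# A segmenter-only writer describes just one output, so the check fails.
try:
    check_output_metadata(outputs, [{"name": "segmentation_masks"}])
except ValueError as err:
    print(err)  # same message as in the question

# Describing both outputs satisfies the check.
check_output_metadata(outputs, [{"name": "segmentation_masks"},
                                {"name": "probability"}])
```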

Is there a way to describe both model outputs when adding the metadata?

Is it possible to create custom writers?

Hi Evans,

If you look at this script: examples/metadata_writer_for_image_classifier.py (master branch of tensorflow/examples on GitHub)

On line 156, it adds the output metadata to the model before saving it.
What I’d try is passing a list with one entry per output, something like [output_meta_1, output_meta_2].

Hope it helps. Since I didn’t try it myself, I can’t guarantee it will work.
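If you pass a list like that, keep in mind that the metadata schema matches output tensor metadata to the subgraph's output tensors by position, so the list has to follow the model's output order. A minimal sketch, assuming the segmentation output is rank 4 ([1, H, W, N]) and the classification output is rank 2 ([1, num_classes]); the helper names are hypothetical, and the input dicts are shaped like the entries returned by tf.lite.Interpreter.get_output_details():

```python
from typing import Dict, List, Sequence, Tuple


def find_output_positions(output_details: Sequence[Dict]) -> Tuple[int, int]:
    """Returns (segmentation_pos, classification_pos) in model output order.

    Assumes one rank-4 segmentation output and one rank-2 classification output.
    """
    ranks = [len(detail["shape"]) for detail in output_details]
    if sorted(ranks) != [2, 4]:
        raise ValueError("expected one rank-4 and one rank-2 output, got %s" % ranks)
    return ranks.index(4), ranks.index(2)


def order_metadata(output_details: Sequence[Dict],
                   segmentation_md, classification_md) -> List:
    """Places each metadata entry at the position of its output tensor."""
    seg_pos, cls_pos = find_output_positions(output_details)
    ordered = [None, None]
    ordered[seg_pos] = segmentation_md
    ordered[cls_pos] = classification_md
    return ordered


# Example details with the classification output listed first.
details = [{"name": "probability", "shape": [1, 3]},
           {"name": "masks", "shape": [1, 256, 256, 4]}]
print(order_metadata(details, "segmentation_md", "classification_md"))
# → ['classification_md', 'segmentation_md']
```

The same position is also the index to use with writer_utils.get_output_tensor_types(model_buffer) when setting a tensor's tensor_type.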


Thanks for the idea. I managed to write something, but I can’t wrap my head around this error:

AttributeError: 'ClassificationTensorMd' object has no attribute 'Pack'

This is my implementation:

from typing import List, Optional, Type

from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils

_MODEL_NAME = "ImageSegmenter"
_MODEL_DESCRIPTION = ("Semantic image segmentation predicts whether each pixel "
                      "of an image is associated with a certain class.")
_INPUT_NAME = "image"
_INPUT_DESCRIPTION = "Input image to be segmented."
_OUTPUT_NAME = "segmentation_masks"
_CLASSIFICATION_OUTPUT_NAME = "probability"
_OUTPUT_DESCRIPTION = "Masks over the target objects with high accuracy."
_CLASSIFICATION_OUTPUT_DESCRIPTION = "Probabilities of the labels respectively."
# The output tensor is in the shape of [1, ImageHeight, ImageWidth, N], where N
# is the number of objects that the segmentation model can recognize. The output
# tensor is essentially a list of grayscale bitmaps, where each value is the
# probability of the corresponding pixel belonging to a certain object type.
# Therefore, the content dimension range of the output tensor is [1, 2].
_CONTENT_DIM_MIN = 1
_CONTENT_DIM_MAX = 2


def _create_segmentation_masks_metadata(
        masks_md: metadata_info.TensorMd) -> _metadata_fb.TensorMetadataT:
    """Creates the metadata for the segmentation masks tensor."""
    masks_metadata = masks_md.create_metadata()

    # Create tensor content information.
    content = _metadata_fb.ContentT()
    content.contentProperties = _metadata_fb.ImagePropertiesT()
    content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.GRAYSCALE
    content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
    # Add the content range. See
    # https://github.com/tensorflow/tflite-support/blob/ace5d3f3ce44c5f77c70284fa9c5a4e3f2f92abb/tensorflow_lite_support/metadata/metadata_schema.fbs#L285-L347
    dim_range = _metadata_fb.ValueRangeT()
    dim_range.min = _CONTENT_DIM_MIN
    dim_range.max = _CONTENT_DIM_MAX
    content.range = dim_range
    masks_metadata.content = content

    return masks_metadata


class MetadataWriter(metadata_writer.MetadataWriter):
    """Writes metadata into an image segmenter."""

    @classmethod
    def create_from_metadata_info(
            cls,
            model_buffer: bytearray,
            general_md: Optional[metadata_info.GeneralMd] = None,
            input_md: Optional[metadata_info.InputImageTensorMd] = None,
            output_md: Optional[List[Type[metadata_info.TensorMd]]] = None,
            ):
        """Creates MetadataWriter based on general/input/outputs information.

    Args:
      model_buffer: valid buffer of the model file.
      general_md: general information about the model.
      input_md: input image tensor information.
      output_md: output segmentation mask tensor information. This tensor is a
        multidimensional array of [1 x mask_height x mask_width x num_classes],
        where mask_width and mask_height are the dimensions of the segmentation
        masks produced by the model, and num_classes is the number of classes
        supported by the model.

    Returns:
      A MetadataWriter object.
    """

        if general_md is None:
            general_md = metadata_info.GeneralMd(
                name=_MODEL_NAME, description=_MODEL_DESCRIPTION)

        if input_md is None:
            input_md = metadata_info.InputImageTensorMd(
                name=_INPUT_NAME,
                description=_INPUT_DESCRIPTION,
                color_space_type=_metadata_fb.ColorSpaceType.RGB)

        segmentation_output_md, classification_output_md = output_md

        if segmentation_output_md is None:
            segmentation_output_md = metadata_info.TensorMd(
                name=_OUTPUT_NAME, description=_OUTPUT_DESCRIPTION)

        if segmentation_output_md.associated_files is None:
            segmentation_output_md.associated_files = []

        if classification_output_md is None:
            classification_output_md = metadata_info.ClassificationTensorMd(
                name=_CLASSIFICATION_OUTPUT_NAME, description=_CLASSIFICATION_OUTPUT_DESCRIPTION)

        if classification_output_md.associated_files is None:
            classification_output_md.associated_files = []

        return super().create_from_metadata(
            model_buffer,
            model_metadata=general_md.create_metadata(),
            input_metadata=[input_md.create_metadata()],
            output_metadata=[
                _create_segmentation_masks_metadata(segmentation_output_md),
                classification_output_md
            ],
            associated_files=[
                file.file_path for file in classification_output_md.associated_files
            ])

    @classmethod
    def create_for_inference(cls, model_buffer: bytearray,
                             input_norm_mean: List[float],
                             input_norm_std: List[float],
                             label_file_paths: List[str],
                             score_calibration_md: Optional[metadata_info.ScoreCalibrationMd] = None):
        """Creates mandatory metadata for TFLite Support inference.

    The parameters required in this method are mandatory when using TFLite
    Support features, such as Task library and Codegen tool (Android Studio ML
    Binding). Other metadata fields will be set to default. If other fields need
    to be filled, use the method `create_from_metadata_info` to edit them.

    Args:
      model_buffer: valid buffer of the model file.
      input_norm_mean: the mean value used in the input tensor normalization
        [1].
      input_norm_std: the std value used in the input tensor normalization [1].
      label_file_paths: paths to the label files [2] in the category tensor.
        Pass in an empty list if the model does not have any label file.
      score_calibration_md: information of the score calibration operation [3]
        in the classification tensor. Optional if the model does not use score
        calibration.
      [1]:
        https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
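      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108
      [3]:
        https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434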

    Returns:
      A MetadataWriter object.
    """
        input_md = metadata_info.InputImageTensorMd(
            name=_INPUT_NAME,
            description=_INPUT_DESCRIPTION,
            norm_mean=input_norm_mean,
            norm_std=input_norm_std,
            color_space_type=_metadata_fb.ColorSpaceType.RGB,
            tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])

        segmentation_output_md = metadata_info.TensorMd(
            name=_OUTPUT_NAME,
            description=_OUTPUT_DESCRIPTION,
            associated_files=None)

        classification_output_md = metadata_info.ClassificationTensorMd(
            name=_CLASSIFICATION_OUTPUT_NAME,
            description=_CLASSIFICATION_OUTPUT_DESCRIPTION,
            label_files=[
                metadata_info.LabelFileMd(file_path=file_path)
                for file_path in label_file_paths[1]
            ],
            tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0],
            score_calibration_md=score_calibration_md)

        return cls.create_from_metadata_info(
            model_buffer, input_md=input_md, output_md=[segmentation_output_md, classification_output_md])
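The AttributeError most likely comes from the output_metadata list passed to create_from_metadata: the first entry is a flatbuffer TensorMetadataT (built by create_metadata() inside _create_segmentation_masks_metadata), but the second entry is the ClassificationTensorMd info object itself. Packing the metadata calls Pack() on every entry, and only the flatbuffer ...T objects have that method. The sketch below uses hypothetical stand-in classes (not the real tflite_support ones) to reproduce the failure and the fix of calling create_metadata() first:

```python
from typing import List


class TensorMetadataT:
    """Stand-in for a flatbuffer object-API class: it knows how to Pack()."""

    def __init__(self, name: str) -> None:
        self.name = name

    def Pack(self, builder: List[str]) -> int:
        builder.append(self.name)  # pretend to serialize into the builder
        return len(builder) - 1


class ClassificationTensorMd:
    """Stand-in for a metadata_info class: a description, not packable."""

    def __init__(self, name: str) -> None:
        self.name = name

    def create_metadata(self) -> TensorMetadataT:
        return TensorMetadataT(self.name)


def pack_outputs(output_metadata: list) -> List[str]:
    """Packs every entry, the way output tensor metadata gets serialized."""
    builder: List[str] = []
    for entry in output_metadata:
        entry.Pack(builder)
    return builder


classification_md = ClassificationTensorMd("probability")
try:
    pack_outputs([TensorMetadataT("segmentation_masks"), classification_md])
except AttributeError as err:
    print(err)  # 'ClassificationTensorMd' object has no attribute 'Pack'

# Converting the info object first makes every entry packable.
print(pack_outputs([TensorMetadataT("segmentation_masks"),
                    classification_md.create_metadata()]))
# → ['segmentation_masks', 'probability']
```

In the final solution, this conversion is handled by delegating to the base class's create_from_metadata_info, which calls create_metadata() on each input and output entry. Separately, label_file_paths[1] in the attempt iterates over the characters of a single path string; the final version iterates over label_file_paths itself.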

I managed to solve the error above, and everything works great now.
This is the final solution:

from typing import List, Optional

from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_info
from tensorflow_lite_support.metadata.python.metadata_writers import metadata_writer
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils

_MODEL_NAME = "ImageClassifier"
MODEL_DESCRIPTION = ("Identify the most prominent object in the image from a "
                     "known set of categories.")
INPUT_NAME = "image"
INPUT_DESCRIPTION = "Input image to be classified."
OUTPUT_NAME = "probability"
OUTPUT_DESCRIPTION = "Probabilities of the labels respectively."


class MetadataWriter(metadata_writer.MetadataWriter):
    """Writes metadata into an image classifier."""

    @classmethod
    def create_from_metadata_info(
            cls,
            model_buffer: bytearray,
            general_md: Optional[metadata_info.GeneralMd] = None,
            input_md: Optional[metadata_info.InputImageTensorMd] = None,
            output_md: Optional[List[metadata_info.TensorMd]] = None):
        """Creates MetadataWriter based on general/input/output information.

    Args:
      model_buffer: valid buffer of the model file.
      general_md: general information about the model. If not specified, default
        general metadata will be generated.
      input_md: input image tensor information. If not specified, default input
        metadata will be generated.
      output_md: [segmentation, classification] output tensor information, in
        the same order as the model's outputs. If not specified, default output
        metadata will be generated.

    Returns:
      A MetadataWriter object.
    """

        if general_md is None:
            general_md = metadata_info.GeneralMd(
                name=_MODEL_NAME, description=MODEL_DESCRIPTION)

        if input_md is None:
            input_md = metadata_info.InputImageTensorMd(
                name=INPUT_NAME,
                description=INPUT_DESCRIPTION,
                color_space_type=_metadata_fb.ColorSpaceType.RGB)

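        # output_md must follow the model's output order:
        # [segmentation masks, classification probabilities].
        if output_md is None:
            output_md = [None, None]
        segmentation_output_md, classification_output_md = output_md

        if segmentation_output_md is None:
            segmentation_output_md = metadata_info.TensorMd(
                name="segmentation_masks",
                description="Masks over the target objects with high accuracy.")

        if classification_output_md is None:
            classification_output_md = metadata_info.ClassificationTensorMd(
                name=OUTPUT_NAME, description=OUTPUT_DESCRIPTION)

        if classification_output_md.associated_files is None:
            classification_output_md.associated_files = []
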
        return super().create_from_metadata_info(
            model_buffer=model_buffer,
            general_md=general_md,
            input_md=[input_md],
            output_md=[segmentation_output_md, classification_output_md],
            associated_files=[
                file.file_path for file in classification_output_md.associated_files
            ])

    @classmethod
    def create_for_inference(
            cls,
            model_buffer: bytearray,
            input_norm_mean: List[float],
            input_norm_std: List[float],
            label_file_paths: List[str],
            score_calibration_md: Optional[metadata_info.ScoreCalibrationMd] = None):
        """Creates mandatory metadata for TFLite Support inference.

    The parameters required in this method are mandatory when using TFLite
    Support features, such as Task library and Codegen tool (Android Studio ML
    Binding). Other metadata fields will be set to default. If other fields need
    to be filled, use the method `create_from_metadata_info` to edit them.

    Args:
      model_buffer: valid buffer of the model file.
      input_norm_mean: the mean value used in the input tensor normalization
        [1].
      input_norm_std: the std value used in the input tensor normalization [1].
      label_file_paths: paths to the label files [2] in the classification
        tensor. Pass in an empty list if the model does not have any label file.
      score_calibration_md: information of the score calibration operation [3]
        in the classification tensor. Optional if the model does not use score
        calibration.
      [1]:
        https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
      [3]:
        https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434

    Returns:
      A MetadataWriter object.
    """
        input_md = metadata_info.InputImageTensorMd(
            name=INPUT_NAME,
            description=INPUT_DESCRIPTION,
            norm_mean=input_norm_mean,
            norm_std=input_norm_std,
            color_space_type=_metadata_fb.ColorSpaceType.RGB,
            tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])

        classification_output_md = metadata_info.ClassificationTensorMd(
            name=OUTPUT_NAME,
            description=OUTPUT_DESCRIPTION,
            label_files=[
                metadata_info.LabelFileMd(file_path=file_path)
                for file_path in label_file_paths
            ],
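            # The classification tensor is the second output (index 1), matching
            # the order of output_md below.
            tensor_type=writer_utils.get_output_tensor_types(model_buffer)[1],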
            score_calibration_md=score_calibration_md)

        segmentation_output_md = metadata_info.TensorMd(
            name="segmentation_masks",
            description="Masks over the target objects with high accuracy.",
            associated_files=None)

        return cls.create_from_metadata_info(
            model_buffer, input_md=input_md, output_md=[segmentation_output_md, classification_output_md])
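A minimal usage sketch for the writer above, assuming it is saved in a module named two_output_writer (a hypothetical name) and that the tflite-support package is installed. writer_utils.load_file, get_metadata_json(), populate() and writer_utils.save_file are the same helpers used in the official metadata writer tutorial; the wrapper function itself is hypothetical:

```python
from typing import List, Optional


def write_two_output_metadata(
        model_path: str,
        output_path: str,
        norm_mean: float,
        norm_std: float,
        label_file_paths: Optional[List[str]] = None) -> None:
    """Writes metadata for a model with [segmentation, classification] outputs.

    Assumes tflite-support is installed and that the MetadataWriter from this
    thread lives in a module named two_output_writer (hypothetical name).
    """
    # Imported inside the function so the helper can be defined without
    # tflite-support being installed.
    from tflite_support.metadata_writers import writer_utils
    from two_output_writer import MetadataWriter  # hypothetical module

    writer = MetadataWriter.create_for_inference(
        writer_utils.load_file(model_path),
        input_norm_mean=[norm_mean],
        input_norm_std=[norm_std],
        label_file_paths=label_file_paths or [])

    # Inspect the generated metadata; both outputs should be listed.
    print(writer.get_metadata_json())

    # populate() returns the model buffer with the metadata and label files.
    writer_utils.save_file(writer.populate(), output_path)
```

For example, write_two_output_metadata("model.tflite", "model_with_metadata.tflite", 127.5, 127.5, ["labels.txt"]), where the paths and the 127.5 values are placeholders: use the mean and std from your own preprocessing. The label files need to exist on disk when populate() runs.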

Hello @Evans_Kiplagat, I also have a segmentation model with a classification branch, so the final model has two outputs. I noticed you didn’t include an example of how to use the writer. Could you share an example call with the parameters you pass?