Batch Inference with tflite

The main goal of batch inference is to reduce the average inference time per image when processing many images at once.

Say I have a large image (2560x1440) and I want to run it through my model, which has an input size of 640x480. Historically, the large input image has been squished down to fit the 640x480 input. While this works, it’s not optimal for use cases that need to detect small objects, because these objects get squished into even smaller pixel representations and can become impossible to detect.
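To put rough numbers on the squishing: going from 2560x1440 to 640x480 shrinks the image 4x horizontally and 3x vertically (and distorts the 16:9 frame to 4:3), so a small object loses most of its pixels. A minimal Kotlin sketch of that arithmetic; the function name, the `Dims` type and the 12x12-pixel object are illustrative only.

```kotlin
// Illustrative helper: size of an object after a (possibly non-uniform) resize.
data class Dims(val width: Int, val height: Int)

// Integer division truncates, which matches "whole pixels left after resizing".
fun scaledObjectSize(obj: Dims, source: Dims, target: Dims): Dims =
    Dims(
        width = obj.width * target.width / source.width,
        height = obj.height * target.height / source.height,
    )
```

For example, `scaledObjectSize(Dims(12, 12), Dims(2560, 1440), Dims(640, 480))` returns `Dims(width=3, height=4)`: a 12x12 object ends up 3 pixels wide and 4 pixels tall.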

Let’s stick with the small-objects use case.

One solution to increase the likelihood of detecting these small objects is to break the 2560x1440 image into smaller pieces, say 12 separate 640x480 images (a 4x3 grid). Each of these images fits nicely into the model, but there are 12 of them!
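The 12 comes from 2560 / 640 = 4 columns and 1440 / 480 = 3 rows. A minimal sketch of computing each tile's top-left corner, assuming non-overlapping tiles and image dimensions that are exact multiples of the tile size (real tiling schemes often add overlap so objects on tile borders are not cut in half); `tileOrigins` and `Tile` are hypothetical names.

```kotlin
// Top-left (x, y) corner of one tile in the original image.
data class Tile(val x: Int, val y: Int)

// Cuts an image into a grid of non-overlapping tiles, row by row.
// Assumes imageW and imageH are multiples of tileW and tileH;
// otherwise the leftover right/bottom edges are simply dropped.
fun tileOrigins(imageW: Int, imageH: Int, tileW: Int, tileH: Int): List<Tile> {
    val tiles = mutableListOf<Tile>()
    for (y in 0..imageH - tileH step tileH) {
        for (x in 0..imageW - tileW step tileW) {
            tiles += Tile(x, y)
        }
    }
    return tiles
}
```

`tileOrigins(2560, 1440, 640, 480)` yields 12 tiles, from `Tile(0, 0)` to `Tile(1920, 960)`.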

If we simply placed the inference code in a for loop and ran it once per tile, it would take roughly 12 times as long as the original squished example and would certainly cause latency issues.

Enter batch inference.

Instead of looping over the tiles and running each one through inference by itself, we can resize the input tensor to accept a batch of 12 images and save time by executing the inference call only once.

That is how it should work, and I’ve seen it firsthand in this PyTorch example.
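In tensor terms, batching changes the input shape from [1, H, W, C] to [12, H, W, C]: all tiles are written back-to-back into one contiguous buffer, with tile n starting at float offset n × H × W × C. A minimal sketch of that packing into a direct, native-order `ByteBuffer` (a buffer type the TFLite Java interpreter accepts as input). The function name, the ARGB-int input format and the [0, 1] pixel scaling are assumptions for illustration, not something from the original post.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Packs N tiles, each given as packed ARGB ints (the format Bitmap.getPixels
// produces on Android), into one NHWC float32 buffer of shape [N, height, width, 3].
// Scaling channels to [0, 1] is an assumption; use whatever the model expects.
fun packBatchNhwc(tiles: List<IntArray>, width: Int, height: Int): ByteBuffer {
    val channels = 3
    val floatsPerTile = width * height * channels
    val buffer = ByteBuffer
        .allocateDirect(tiles.size * floatsPerTile * Float.SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
    tiles.forEachIndexed { n, pixels ->
        require(pixels.size == width * height) { "tile $n has the wrong pixel count" }
        val tileOffset = n * floatsPerTile // where tile n starts, counted in floats
        pixels.forEachIndexed { p, argb ->
            val base = (tileOffset + p * channels) * Float.SIZE_BYTES
            buffer.putFloat(base, ((argb shr 16) and 0xFF) / 255f)                    // R
            buffer.putFloat(base + Float.SIZE_BYTES, ((argb shr 8) and 0xFF) / 255f)  // G
            buffer.putFloat(base + 2 * Float.SIZE_BYTES, (argb and 0xFF) / 255f)      // B
        }
    }
    // Absolute puts leave the position at 0, so the buffer is ready to be read.
    return buffer
}
```

The design point is that the batch dimension is just the outermost stride: nothing about each tile changes, only where it lands in the buffer.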

When I do the same thing with TFLite, inference time grows linearly with the number of images. There are no savings at all.

So my question is: has anyone been able to implement batch processing with the TFLite API and actually see time savings during inference?



Hi Isaac,

Can you also share the code you tried?
Are you using raw TFLite inference or the Task Library?

Here is the code.

I’m pretty sure this is raw TFLite inference, as it’s imported from org.tensorflow.lite and not import

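import android.content.Context
import android.graphics.Bitmap
import android.graphics.Matrix
import android.os.SystemClock
import android.util.Log
import org.tensorflow.lite.DataType
import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.InterpreterFactory
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Takes in a large bitmap (1440x1080), cuts it into NUM_TILES square tiles and
// classifies all of them with ONE batched inference call.
// OBJConstants is the app's own config object; `context` replaces applicationContext.
fun analyzeImage(context: Context, image: Bitmap?) {
    if (image == null) return

    val tileSize = OBJConstants.TARGET_VIDEO_SIZE
    // Assumes OBJConstants.COLORS == 3 (RGB), matching the three channels written below
    val floatsPerTile = tileSize * tileSize * OBJConstants.COLORS

    // Input buffer for the whole batch: [NUM_TILES, tileSize, tileSize, COLORS] float32
    val inputBuffer = ByteBuffer
        .allocateDirect(OBJConstants.NUM_TILES * floatsPerTile * Float.SIZE_BYTES)
        .order(ByteOrder.nativeOrder())

    // Output buffer: the output is batched too, so one row of 1001 scores per tile.
    // FLOAT32 because the model is float32 (previously [1, 1001] and UINT8).
    val probabilityBuffer = TensorBuffer.createFixedSize(
        intArrayOf(OBJConstants.NUM_TILES, 1001), DataType.FLOAT32)

    val tflite: InterpreterApi = try {
        // For now, the model has to be Float32 because that is the DataType of the input buffer
        val tfliteModel = FileUtil.loadMappedFile(context, "models/efficientnetv2-b0_224.tflite")
        // val tfliteModel = FileUtil.loadMappedFile(context, "phil_mvn2_224_uint8.tflite")
        InterpreterFactory().create(tfliteModel, InterpreterApi.Options())
    } catch (e: IOException) {
        Log.e("tfliteSupport", "Error reading model", e)
        return
    }

    tflite.use { interpreter ->
        // Resize the model input to the number of tiles, then reallocate tensors
        interpreter.resizeInput(0, intArrayOf(OBJConstants.NUM_TILES, tileSize, tileSize, OBJConstants.COLORS))
        interpreter.allocateTensors()

        // Rotate each tile 90 degrees (assumed to be what the Matrix was for)
        val matrix = Matrix().apply { postRotate(90f) }

        // i (row) and k (column) are the top-left pixel position of each tile
        var count = 0
        var i = OBJConstants.H_CROP_START
        while (i < OBJConstants.endHCrop) {
            var k = OBJConstants.W_CROP_START
            while (k < OBJConstants.endWCrop) {
                val tile = Bitmap.createBitmap(image, k, i, tileSize, tileSize, matrix, false)
                val pixels = IntArray(tileSize * tileSize)
                tile.getPixels(pixels, 0, tileSize, 0, 0, tileSize, tileSize)
                // Tile `count` starts count * floatsPerTile floats into the buffer.
                // Scaling to [0, 1] is an assumption; use the model's expected normalization.
                val base = count * floatsPerTile
                pixels.forEachIndexed { p, argb ->
                    val offset = (base + p * 3) * Float.SIZE_BYTES
                    inputBuffer.putFloat(offset, ((argb shr 16) and 0xFF) / 255f)     // R
                    inputBuffer.putFloat(offset + 4, ((argb shr 8) and 0xFF) / 255f)  // G
                    inputBuffer.putFloat(offset + 8, (argb and 0xFF) / 255f)          // B
                }
                count++
                k += tileSize
            }
            i += tileSize
        }
        check(count == OBJConstants.NUM_TILES) { "Cropped $count tiles, expected ${OBJConstants.NUM_TILES}" }

        // Run inference once for the whole batch and record processing time
        val startTime = SystemClock.elapsedRealtime()
        interpreter.run(inputBuffer, probabilityBuffer.buffer)
        val inferenceTime = SystemClock.elapsedRealtime() - startTime
        Log.d(OBJConstants.TAG, "native: ${interpreter.lastNativeInferenceDurationNanoseconds} ns, total: $inferenceTime ms")
    }
}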
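Once the input is resized to a batch of NUM_TILES, the output tensor is batched too. Assuming a classifier with a single [batch, 1001] output, the output buffer must hold NUM_TILES × 1001 values rather than 1 × 1001, and the flat result then has to be split back per tile. A minimal sketch of that split; `topClassPerTile` is a hypothetical helper, and the flat array could come from something like `TensorBuffer.getFloatArray()`.

```kotlin
// Splits a flat, row-major [numTiles, numClasses] output into the
// highest-scoring class index for each tile.
fun topClassPerTile(output: FloatArray, numTiles: Int, numClasses: Int): IntArray {
    require(output.size == numTiles * numClasses) {
        "expected ${numTiles * numClasses} values, got ${output.size}"
    }
    return IntArray(numTiles) { tile ->
        val start = tile * numClasses // first score of this tile's row
        var best = 0
        for (c in 1 until numClasses) {
            if (output[start + c] > output[start + best]) best = c
        }
        best
    }
}
```

For example, `topClassPerTile(probabilityBuffer.floatArray, OBJConstants.NUM_TILES, 1001)` would return one class index per tile.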

I’ve run some tests on the Edge TPU and various machines, and it seems the savings are only visible when using a GPU or TPU, not the CPU. I’m going to test this on an Android phone with a compatible GPU to see whether the trend holds there.
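Since the savings appear to depend on where the model runs (CPU vs GPU/TPU), it may help to measure the looped and batched paths the same way on each device. A hedged sketch of such a comparison: `runSingle` and `runBatch` are placeholders you would wire to a batch-1 interpreter and a batch-N interpreter (for example, with and without a GPU delegate); nothing here is TFLite API, and the harness only measures wall-clock time.

```kotlin
// Hypothetical harness: compares per-image latency of N single-image calls
// against one batched call covering the same N images.
data class BatchTiming(val loopedNsPerImage: Double, val batchedNsPerImage: Double) {
    val speedup: Double get() = loopedNsPerImage / batchedNsPerImage
}

fun compareLoopedVsBatched(
    numImages: Int,
    runs: Int = 10,
    warmupRuns: Int = 2,
    runSingle: (index: Int) -> Unit,
    runBatch: () -> Unit,
): BatchTiming {
    // Warm up both paths so one-time costs (allocation, delegate setup) are excluded.
    repeat(warmupRuns) {
        for (i in 0 until numImages) runSingle(i)
        runBatch()
    }
    var loopedNs = 0L
    var batchedNs = 0L
    repeat(runs) {
        val t0 = System.nanoTime()
        for (i in 0 until numImages) runSingle(i)
        val t1 = System.nanoTime()
        runBatch()
        val t2 = System.nanoTime()
        loopedNs += t1 - t0
        batchedNs += t2 - t1
    }
    val images = runs.toDouble() * numImages
    return BatchTiming(loopedNs / images, batchedNs / images)
}
```

A `speedup` close to 1.0 would match the linear scaling described above; a value well above 1.0 would indicate that batching pays off on that device.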