Batch Inference on Android - Classification

Hi-

I have been researching and experimenting with batch inference on Android and I can't quite get it right. As a proof of concept, I ran my model (MobileNetV2 classification) on the Coral Dev Board and it handled batch inference just fine.

Now I am trying to run batch inference on an Android phone, and it is proving quite difficult.
I'm using the Android Classification Example as my starting point, and I chose the lib_support_api path because the lib_task_api documentation specifically says it does not support batch inference.

It took me a while to get familiar with this code, but I eventually figured out that everything that needed changing to make batch inference happen would probably be in the ../lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java file.

My strategy for cropping and processing a 1920x1080 input stream:

  • Reshape the model from [1,224,224,3] → [16,224,224,3]

  • Input comes in as a bitmap

  • Break the center of that large bitmap into 16 [224,224,3] chunks using for loops and createBitmap

  • Use TensorImage.load() on each bitmap and store the results in an ArrayList<TensorImage> (a rough sketch of this tiling is shown right after this list)
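
For reference, the tiling step could look roughly like this (a sketch only; tileFrame is an illustrative helper name and the 4x4 grid over the center of the frame is an assumption on my part, not the exact code from the example app):

    import android.graphics.Bitmap;
    import org.tensorflow.lite.DataType;
    import org.tensorflow.lite.support.image.TensorImage;
    import java.util.ArrayList;
    import java.util.List;

    // Crop the center of the frame into a 4x4 grid of 224x224 tiles and
    // wrap each tile in a TensorImage.
    List<TensorImage> tileFrame(Bitmap frame) {
        final int tileSize = 224;
        final int grid = 4; // 4 x 4 = 16 tiles
        final int left = (frame.getWidth() - grid * tileSize) / 2;
        final int top = (frame.getHeight() - grid * tileSize) / 2;

        List<TensorImage> tiles = new ArrayList<>();
        for (int row = 0; row < grid; row++) {
            for (int col = 0; col < grid; col++) {
                Bitmap tile = Bitmap.createBitmap(frame,
                        left + col * tileSize, top + row * tileSize, tileSize, tileSize);
                TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
                tensorImage.load(tile);
                tiles.add(tensorImage);
            }
        }
        return tiles;
    }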

This TensorImage business is where I have run into my main problems.
I tried using the Interpreter.runForMultipleInputsOutputs(inputs, outputs) function, but I realized that it is intended for a model that actually has multiple input tensors. So it expects something like this:
input{ 0:[1,224,224,3], 1:[1,224,224,3], 2:[1,224,224,3]... }

Whereas my input looks like this:

input{ 0:[16,224,224,3] }

I could try to train a model that takes inputs in the multi-input format, but I don't know exactly what that would entail.
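
To illustrate the difference (the interpreter and buffer names below are just placeholders): runForMultipleInputsOutputs() maps each element of the inputs array to a separate input tensor, whereas a batched model still has a single input tensor and goes through plain run():

    import org.tensorflow.lite.Interpreter;
    import java.nio.ByteBuffer;
    import java.util.HashMap;
    import java.util.Map;

    void compareCalls(Interpreter interpreter, ByteBuffer input0, ByteBuffer input1,
                      ByteBuffer batchedInput, ByteBuffer output, ByteBuffer batchedOutput) {
        // For a model that really has several distinct input tensors:
        Object[] inputs = {input0, input1};          // element i -> input tensor i
        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, output);                      // key -> output tensor index
        interpreter.runForMultipleInputsOutputs(inputs, outputs);

        // For one input tensor whose first dimension is the batch, e.g. [16,224,224,3]:
        interpreter.run(batchedInput, batchedOutput);
    }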

After learning the above, I went back to using the Interpreter.run(input, output) function, which presented issues of its own.

Specifically, the issue now is that Interpreter.run(input, output) expects input it can digest. It does not accept my ArrayList<TensorImage> and complains that it does not understand the TensorImage data type. This makes sense, because there is a check in the Interpreter code requiring that if the input is an array, its elements must be one of these types:

  • uint8
  • int64
  • String
  • Bool

So now the question is: can I use a bitmap to create an ArrayList of uint8 values (or some equivalent buffer) such that all the information is still present?
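
For what it's worth, the pixel data can be copied out of a Bitmap without losing information; rather than an ArrayList, a single direct ByteBuffer of 16 * 224 * 224 * 3 bytes could hold the whole batch. A sketch, assuming an ARGB_8888 bitmap and a uint8 RGB model input (writeRgbBytes is a made-up helper name):

    import android.graphics.Bitmap;
    import java.nio.ByteBuffer;

    // Append one 224x224 ARGB_8888 bitmap to `buffer` as packed uint8 RGB bytes.
    void writeRgbBytes(Bitmap bitmap, ByteBuffer buffer) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        int[] pixels = new int[width * height];
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height);

        for (int pixel : pixels) {
            buffer.put((byte) ((pixel >> 16) & 0xFF)); // R
            buffer.put((byte) ((pixel >> 8) & 0xFF));  // G
            buffer.put((byte) (pixel & 0xFF));         // B (alpha is dropped)
        }
    }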

There is a whole lot of janky behavior around batch inference in TensorFlow, and not just in the Android/Java libraries; the Python side has rough edges too. Maybe they have deemed this feature not worth their time, but at the moment it feels only half supported.

Sooooooo @tensorflow, would you kindly support batch inference? It's OP.

  • Support Batch Inference
  • Don’t Support Batch Inference


It looks like we just need to resize the input:
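
Something like this, for instance (a sketch; the interpreter is assumed to already be created from the model):

    import org.tensorflow.lite.Interpreter;

    void resizeForBatch(Interpreter interpreter) {
        // Grow the batch dimension of input tensor 0 from 1 to 16.
        interpreter.resizeInput(0, new int[] {16, 224, 224, 3});
        // Re-allocate tensor buffers so the new shape takes effect before run().
        interpreter.allocateTensors();
    }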

As I said in my post, resizing is not the issue. It's the way the images themselves are processed and fed into the network.

I see, so you’re wondering how to shuffle the data around? With JavaCV, we could do something like this:

    import org.bytedeco.javacpp.indexer.FloatIndexer;
    import org.bytedeco.javacpp.indexer.UByteIndexer;
    import org.bytedeco.javacv.AndroidFrameConverter;

    // Converts Android Bitmaps into Frames whose pixels can be indexed directly.
    AndroidFrameConverter bitmapConverter = new AndroidFrameConverter();
    // View the float tensor buffer as a [16, 224, 224, 3] array.
    FloatIndexer tensorIndexer = FloatIndexer.create(tensorBuffer, 16, 224, 224, 3);
    for (int i = 0; i < 16; i++) {
        // Pixel accessor for the i-th 224x224 tile bitmap.
        UByteIndexer idx = bitmapConverter.convert(images[i]).createIndexer();
        for (int j = 0; j < 224; j++) {
            for (int k = 0; k < 224; k++) {
                // Copy the R, G, B channels of pixel (j, k) into batch element i.
                tensorIndexer.put(i, j, k, 0, idx.get(j, k, 0));
                tensorIndexer.put(i, j, k, 1, idx.get(j, k, 1));
                tensorIndexer.put(i, j, k, 2, idx.get(j, k, 2));
            }
        }
    }

It would be more efficient to index your large original image directly instead of creating multiple intermediate images, though; see Bytedeco - Third release at Bytedeco.
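
For instance, reusing bitmapConverter and tensorIndexer from the snippet above, it could look roughly like this (a sketch, assuming the 16 tiles form a 4x4 grid starting at offset (left, top) inside the full frame bitmap):

    // Index the full frame once and copy each 224x224 tile straight into the
    // tensor, without allocating intermediate Bitmaps.
    UByteIndexer frameIdx = bitmapConverter.convert(frame).createIndexer();
    for (int i = 0; i < 16; i++) {
        int rowOff = top + (i / 4) * 224;  // tile origin inside the frame
        int colOff = left + (i % 4) * 224;
        for (int j = 0; j < 224; j++) {
            for (int k = 0; k < 224; k++) {
                tensorIndexer.put(i, j, k, 0, frameIdx.get(rowOff + j, colOff + k, 0));
                tensorIndexer.put(i, j, k, 1, frameIdx.get(rowOff + j, colOff + k, 1));
                tensorIndexer.put(i, j, k, 2, frameIdx.get(rowOff + j, colOff + k, 2));
            }
        }
    }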

Nice library :grinning:. I'll check it out and let you know if I can solve my issue with it.

I’ve managed to solve the above issue. Here’s the code.

    import android.os.SystemClock
    import android.util.Log
    import org.tensorflow.lite.DataType
    import org.tensorflow.lite.InterpreterApi
    import org.tensorflow.lite.InterpreterFactory
    import org.tensorflow.lite.support.common.FileUtil
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
    import java.nio.ByteBuffer
    import java.nio.ByteOrder
    import java.nio.FloatBuffer

    object OBJConstants {
        const val TAG = "Classifier"
        private const val FRAME_W = 1440
        private const val FRAME_H = 1080
        const val TARGET_VIDEO_SIZE = 224
        const val NUM_TILES = 16
        const val COLORS = 3

        // Per-tile and whole-batch element counts. The input is Float32,
        // so the actual byte sizes are 4x these numbers.
        const val BITMAP_SIZE_BYTES = TARGET_VIDEO_SIZE * TARGET_VIDEO_SIZE * COLORS
        const val MODEL_INPUT_SIZE = NUM_TILES * BITMAP_SIZE_BYTES
    }

    // Input buffer -- it is critical that the model input size is correct.
    // A direct, native-order buffer can be handed straight to the interpreter.
    private val inTensorBuffer: FloatBuffer = ByteBuffer
        .allocateDirect(OBJConstants.MODEL_INPUT_SIZE * Float.SIZE_BYTES)
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer()

    // The tflite interpreter, created from the model file (a MappedByteBuffer).
    // For now the model has to be Float32, because that is the DataType of the input buffer.
    private val tfliteModel =
        FileUtil.loadMappedFile(applicationContext, "models/efficientnetv2-b0_224.tflite")
    private val tflite: InterpreterApi =
        InterpreterFactory().create(tfliteModel, InterpreterApi.Options())

    // The output buffer where the probabilities are stored. After resizing the input,
    // the output batch dimension matches NUM_TILES.
    private val probabilityBuffer =
        TensorBuffer.createFixedSize(intArrayOf(OBJConstants.NUM_TILES, 1001), DataType.FLOAT32)

    fun setupInterpreter() {
        // Resize the model input to accept the desired number of tiles.
        tflite.resizeInput(0, intArrayOf(OBJConstants.NUM_TILES,
            OBJConstants.TARGET_VIDEO_SIZE, OBJConstants.TARGET_VIDEO_SIZE, OBJConstants.COLORS))
        tflite.allocateTensors()
    }

    fun analyzeImage() {
        // Run inference and record processing time.
        val startTime = SystemClock.elapsedRealtime()
        tflite.run(inTensorBuffer, probabilityBuffer.buffer)
        Log.d(OBJConstants.TAG, tflite.lastNativeInferenceDurationNanoseconds.toString())
        val inferenceTime = SystemClock.elapsedRealtime() - startTime
    }

               

The main issues I faced here were improper sizing of the input buffer and not paying close attention to the size (in bytes) of each bitmap I created, which caused a lot of buffer underflows. Sorry the snippet above is rough; think of it as pseudocode.

However, even after finding this solution, inference time still scales linearly with the number of images. This is troubling, because my main reason for figuring out batch processing was to bring down inference time when the number of images is high. I suspect it has something to do with the way the TFLite API is implemented, but I could be wrong. I'm going to make a separate post about it.