TensorFlow Lite Interpreter error

Hi All,

I have a problem with the Interpreter in TFLite.
I started a personal mini project using the [model_personalization] (examples/lite/examples/model_personalization at master · tensorflow/examples · GitHub) example by @khanhlvg and @lu-wang-g as the base. The main idea is that I want to specify a base model and a Sequential head model for my transfer_learning_model.
I now have a problem in the loadBottleneck function from the project.
The difference between the example and my project is that my images are loaded from a dataset placed in the assets directory.

This is how I load the images:

    private void addSample(String photoPath, Boolean isTraining) throws IOException {
        BitmapFactory.Options options = new BitmapFactory.Options();
        options.inPreferredConfig = Bitmap.Config.ARGB_8888;
        Bitmap bitmap =  BitmapFactory.decodeStream(this.context.getAssets().open(photoPath), null, options);
        String sampleClass = get_class(photoPath);

        // get rgb equivalent and class
        float[][][] rgbImage = prepareImage(bitmap);

        // add to the list.
        try {
            this.tlModel.addSample(rgbImage, sampleClass, isTraining).get();
        } catch (ExecutionException e) {
            throw new RuntimeException("Failed to add sample to model", e.getCause());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can observe the interruption.
            Thread.currentThread().interrupt();
        }
    }

This is how the image is preprocessed:

    private static float[][][] prepareImage(Bitmap bitmap)  {
        int modelImageSize = TransferLearningModelWrapper.IMAGE_SIZE;

        float[][][] normalizedRgb = new float[modelImageSize][modelImageSize][3];

        for (int y = 0; y < modelImageSize; y++) {
            for (int x = 0; x < modelImageSize; x++) {
                int rgb = bitmap.getPixel(x, y);

                float r = ((rgb >> 16) & LOWER_BYTE_MASK) * (1 / 255.0f);
                float g = ((rgb >> 8) & LOWER_BYTE_MASK) * (1 / 255.0f);
                float b = (rgb & LOWER_BYTE_MASK) * (1 / 255.0f);

                normalizedRgb[y][x][0] = r;
                normalizedRgb[y][x][1] = g;
                normalizedRgb[y][x][2] = b;
            }
        }

        return normalizedRgb;
    }
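
As a sanity check on the unpacking arithmetic above, here is a minimal sketch that applies the same shifts and mask to one known ARGB pixel. The class name `PixelUnpackDemo` is hypothetical, and `LOWER_BYTE_MASK = 0xFF` is an assumption, since the constant's value is not shown in this post.

```java
class PixelUnpackDemo {
    // Assumed value; the real constant is defined elsewhere in the project.
    static final int LOWER_BYTE_MASK = 0xFF;

    // Same per-pixel arithmetic as prepareImage: ARGB int -> {r, g, b} in [0, 1].
    static float[] unpack(int argb) {
        float r = ((argb >> 16) & LOWER_BYTE_MASK) * (1 / 255.0f);
        float g = ((argb >> 8) & LOWER_BYTE_MASK) * (1 / 255.0f);
        float b = (argb & LOWER_BYTE_MASK) * (1 / 255.0f);
        return new float[] {r, g, b};
    }

    public static void main(String[] args) {
        // Opaque pixel with R=0x33 (51), G=0x66 (102), B=0x99 (153).
        float[] rgb = unpack(0xFF336699);
        System.out.println(rgb[0] + " " + rgb[1] + " " + rgb[2]);
    }
}
```

With this pixel the three channels should come out close to 51/255, 102/255 and 153/255, so the normalization itself looks fine to me.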

This is where the loadBottleneck method is called:

  public Future<Void> addSample(float[][][] image, String className, Boolean isTraining) {
    checkNotTerminating();

    if (!classes.containsKey(className)) {
      throw new IllegalArgumentException(String.format(
          "Class \"%s\" is not one of the classes recognized by the model", className));
    }

    return executor.submit(
        () -> {
          if (Thread.interrupted()) {
            return null;
          }

          trainingInferenceLock.lockInterruptibly();
          try {
            float[] bottleneck = model.loadBottleneck(image);
            if (isTraining) {
              trainingSamples.add(new TrainingSample(bottleneck, oneHotEncodedClass.get(className)));
            } else {
              testingSamples.add(new TestingSample(image, className));
            }
          } finally {
            trainingInferenceLock.unlock();
          }

          return null;
        });
  }

This is the loadBottleneck method:

  float[] loadBottleneck(float[][][] image) {

    Map<String, Object> inputs = new HashMap<>();
    inputs.put("feature", new float[][][][]{image});
    Map<String, Object> outputs = new HashMap<>();
    float[][] bottleneck = new float[1][BOTTLENECK_SIZE];
    outputs.put("bottleneck", bottleneck);
    this.interpreter.runSignature(inputs, outputs, "load");
    return bottleneck[0];
  }

I debugged inside this method, and the inputs and outputs seem to be created properly, but when I call “this.interpreter.runSignature” I get an error inside the runSignature function.

This is the error I get:

     Caused by: java.lang.IllegalArgumentException: Input error: input feature not found.

Debugging further into the Interpreter, I found that the failure occurs in NativeInterpreterWrapper, at lines 171 → 174, where this code is written:

                    Entry input;
                    for(Iterator var14 = inputs.entrySet().iterator(); var14.hasNext(); inputsList[signatureRunnerWrapper.getInputIndex((String)input.getKey())] = input.getValue()) {
                        input = (Entry)var14.next();
                    }

Somewhere in these 4 lines the error is raised.
The “inputs” variable is a Map<String, Object>, as it should be: the key is the string “feature”, the name of the input variable of the Python method load, and the value is the 4-dimensional image of shape [1][32][32][3].
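
For readability, the loop from NativeInterpreterWrapper is roughly equivalent to the sketch below. This is not the TFLite implementation: `getInputIndex` here is a hypothetical stand-in for `signatureRunnerWrapper.getInputIndex`, backed by a plain map of the input names the signature declares, and sizing `inputsList` from `inputs.size()` is my assumption. The exception message is copied from my stack trace.

```java
import java.util.HashMap;
import java.util.Map;

class SignatureInputOrdering {

    // Hypothetical stand-in for signatureRunnerWrapper.getInputIndex(name):
    // it fails when the name is not among the inputs the signature declares.
    static int getInputIndex(Map<String, Integer> declaredInputs, String name) {
        Integer index = declaredInputs.get(name);
        if (index == null) {
            throw new IllegalArgumentException("Input error: input " + name + " not found.");
        }
        return index;
    }

    // Readable form of the loop: place each value at the index that the
    // signature assigns to its input name.
    static Object[] orderInputs(Map<String, Integer> declaredInputs, Map<String, Object> inputs) {
        Object[] inputsList = new Object[inputs.size()]; // assumed sizing
        for (Map.Entry<String, Object> input : inputs.entrySet()) {
            inputsList[getInputIndex(declaredInputs, input.getKey())] = input.getValue();
        }
        return inputsList;
    }

    public static void main(String[] args) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("feature", new float[1][32][32][3]);

        // Case 1: the signature declares "feature", so the lookup succeeds.
        Map<String, Integer> declared = new HashMap<>();
        declared.put("feature", 0);
        System.out.println("ordered inputs: " + orderInputs(declared, inputs).length);

        // Case 2: the signature declares a different name (hypothetical), so the
        // lookup throws the same kind of error as in my stack trace.
        Map<String, Integer> other = new HashMap<>();
        other.put("some_other_name", 0);
        try {
            orderInputs(other, inputs);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

If this reading is right, the only step in those lines that can raise "not found" is the name lookup, which would mean the "load" signature in the converted model does not expose an input called "feature", even though the map I pass in looks correct.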

I’ve been stuck on this for some time, so I decided to ask for help :slight_smile:
I should also specify that I use TensorFlow 2.7.0 to create the TFLite model and ‘org.tensorflow:tensorflow-lite:2.7.0’ in the Android app.
