Hi all,
I have a problem with the tflite Interpreter.
I started a personal mini project based on the model_personalization example (examples/lite/examples/model_personalization at master · tensorflow/examples · GitHub)
by @khanhlvg and @lu-wang-g.
The main idea is that I wanted to specify a base model and a Sequential head model for my transfer_learning_model.
I now seem to have a problem with the load_bottlenecks function from the project.
The difference between the example and my project is that the images are loaded from a dataset placed in the assets directory.
This is how I load the images:
private void addSample(String photoPath, Boolean isTraining) throws IOException {
    BitmapFactory.Options options = new BitmapFactory.Options();
    options.inPreferredConfig = Bitmap.Config.ARGB_8888;
    Bitmap bitmap = BitmapFactory.decodeStream(this.context.getAssets().open(photoPath), null, options);
    String sampleClass = get_class(photoPath);
    // get rgb equivalent and class
    float[][][] rgbImage = prepareImage(bitmap);
    // add to the list.
    try {
        this.tlModel.addSample(rgbImage, sampleClass, isTraining).get();
    } catch (ExecutionException e) {
        throw new RuntimeException("Failed to add sample to model", e.getCause());
    } catch (InterruptedException e) {
        // no-op
    }
}
This is how the image is preprocessed:
private static float[][][] prepareImage(Bitmap bitmap) {
    int modelImageSize = TransferLearningModelWrapper.IMAGE_SIZE;
    // Note: this loop assumes the decoded bitmap is at least
    // IMAGE_SIZE x IMAGE_SIZE; getPixel throws if the image is smaller.
    float[][][] normalizedRgb = new float[modelImageSize][modelImageSize][3];
    for (int y = 0; y < modelImageSize; y++) {
        for (int x = 0; x < modelImageSize; x++) {
            int rgb = bitmap.getPixel(x, y);
            // Unpack each channel from the packed ARGB int and scale to [0, 1].
            float r = ((rgb >> 16) & LOWER_BYTE_MASK) * (1 / 255.0f);
            float g = ((rgb >> 8) & LOWER_BYTE_MASK) * (1 / 255.0f);
            float b = (rgb & LOWER_BYTE_MASK) * (1 / 255.0f);
            normalizedRgb[y][x][0] = r;
            normalizedRgb[y][x][1] = g;
            normalizedRgb[y][x][2] = b;
        }
    }
    return normalizedRgb;
}
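As a sanity check, the channel unpacking above can be exercised in isolation, away from Android. This is a minimal, self-contained sketch of the same bit-mask math; LOWER_BYTE_MASK is assumed to be 0xFF, as in the example project:

```java
public class PixelNormalization {
    // Assumed to match the example project's constant.
    private static final int LOWER_BYTE_MASK = 0xFF;

    // Unpacks a packed ARGB int and normalizes each channel to [0, 1],
    // mirroring the body of the prepareImage loop.
    static float[] normalizePixel(int argb) {
        float r = ((argb >> 16) & LOWER_BYTE_MASK) * (1 / 255.0f);
        float g = ((argb >> 8) & LOWER_BYTE_MASK) * (1 / 255.0f);
        float b = (argb & LOWER_BYTE_MASK) * (1 / 255.0f);
        return new float[] {r, g, b};
    }

    public static void main(String[] args) {
        // 0xFFFF0000 is opaque pure red, so the result is {1.0, 0.0, 0.0}.
        float[] red = normalizePixel(0xFFFF0000);
        System.out.println(red[0] + " " + red[1] + " " + red[2]);
    }
}
```

If this prints the expected values, the input tensor contents are probably fine and the problem is on the signature side rather than in the preprocessing.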
This is where the loadBottleneck method is called:
public Future<Void> addSample(float[][][] image, String className, Boolean isTraining) {
    checkNotTerminating();
    if (!classes.containsKey(className)) {
        throw new IllegalArgumentException(String.format(
            "Class \"%s\" is not one of the classes recognized by the model", className));
    }
    return executor.submit(
        () -> {
            if (Thread.interrupted()) {
                return null;
            }
            trainingInferenceLock.lockInterruptibly();
            try {
                float[] bottleneck = model.loadBottleneck(image);
                if (isTraining) {
                    trainingSamples.add(new TrainingSample(bottleneck, oneHotEncodedClass.get(className)));
                } else {
                    testingSamples.add(new TestingSample(image, className));
                }
            } finally {
                trainingInferenceLock.unlock();
            }
            return null;
        });
}
This is the loadBottleneck method:
float[] loadBottleneck(float[][][] image) {
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("feature", new float[][][][] {image});
    Map<String, Object> outputs = new HashMap<>();
    float[][] bottleneck = new float[1][BOTTLENECK_SIZE];
    outputs.put("bottleneck", bottleneck);
    this.interpreter.runSignature(inputs, outputs, "load");
    return bottleneck[0];
}
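One way to narrow an "input feature not found" error down is to ask the Interpreter which signature keys and input/output names it actually loaded from the .tflite file; the names baked in by the converter may differ from what the Python code suggests. This is an untested sketch (not from the example project) assuming the signature-inspection methods on org.tensorflow.lite.Interpreter; they are still marked experimental around TFLite 2.7, and on some versions getSignatureKeys() may be named getSignatureDefNames() instead:

```java
// Debugging sketch: dump every signature key and its input/output names
// so they can be compared against "load", "feature", and "bottleneck".
void dumpSignatures(org.tensorflow.lite.Interpreter interpreter) {
    for (String key : interpreter.getSignatureKeys()) {
        Log.d("Signatures", "signature key: " + key);
        for (String inputName : interpreter.getSignatureInputs(key)) {
            Log.d("Signatures", "  input: " + inputName);
        }
        for (String outputName : interpreter.getSignatureOutputs(key)) {
            Log.d("Signatures", "  output: " + outputName);
        }
    }
}
```

If the logged input name differs from "feature" (for example, if the converter prefixed or renamed it), the map key passed to runSignature has to match that exact string.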
I have to specify that I debugged inside this method, and the inputs and outputs seem to be created properly, but when I call this.interpreter.runSignature I get an error inside the runSignature function.
This is the error I get:
Caused by: java.lang.IllegalArgumentException: Input error: input feature not found.
By debugging even further into the Interpreter, I found that it crashes more specifically in NativeInterpreterWrapper, at lines 171 → 174, where this code is:
Entry input;
for (Iterator var14 = inputs.entrySet().iterator();
     var14.hasNext();
     inputsList[signatureRunnerWrapper.getInputIndex((String) input.getKey())] = input.getValue()) {
    input = (Entry) var14.next();
}
Somewhere in these four lines the error pops out.
The "inputs" variable is a Map<String, Object>, as it should be; the String key is "feature" (the name of the input variable of the Python load method), and the Object is the image, a 4-dimensional array of shape [1][32][32][3].
I've been stuck on this for some time, so I decided to ask for help.
I should also specify that I use TensorFlow 2.7.0 to create the tflite model, and the dependency
'org.tensorflow:tensorflow-lite:2.7.0' in the Android app.