TFLite model gives different outputs on Android and in Python for the same inputs. The Python outputs are realistic, but in Java the output stays the same for most inputs. Why does the Java code behave this way?

I’m building a very simple TensorFlow model that predicts x+1. I’ll deploy it in an Android application, so I convert it to TFLite format. Building the model:

import tensorflow as tf
# Create a simple Keras model.      
x = [1,2,3,4,5,6,7,8,9,10]
y = [2,3,4,5,6,7,8,9,10,11]

model = tf.keras.models.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=50)

path_file = 'saved_model/hello_world_tensorflow'
tf.saved_model.save(model, path_file)

import pathlib
# Convert the model.
converter = tf.lite.TFLiteConverter.from_saved_model(path_file)
tflite_model = converter.convert()
tflite_model_file = pathlib.Path('model1.tflite')
tflite_model_file.write_bytes(tflite_model)

Running the model in Python:

import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="model1.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on a fixed input value.
input_shape = input_details[0]['shape']
print(input_shape)
input_data = np.array([[3]], dtype=np.float32) # 3 is the input here
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data,input_data)

Running the model in Java on Android (MainActivity.java):

package ar.labs.androidml;

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;
import android.widget.Toast;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.nio.ByteBuffer;

import ar.labs.androidml.ml.Model1;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Button btn= findViewById(R.id.button);
        btn.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                try{
                    EditText inputEditText;

                    inputEditText = findViewById(R.id.editTextNumberDecimal);
                    Float data= Float.valueOf(inputEditText.getText().toString());
                    ByteBuffer byteBuffer= ByteBuffer.allocateDirect(1*4);
                    byteBuffer.putFloat(data);

                    Model1 model = Model1.newInstance(getApplicationContext());

                    // Creates inputs for reference.
                    TensorBuffer inputFeature0 = TensorBuffer.createFixedSize(new int[]{1, 1}, DataType.FLOAT32);
                    inputFeature0.loadBuffer(byteBuffer);

                    // Runs model inference and gets result.
                    Model1.Outputs outputs = model.process(inputFeature0);
                    TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();

                    // Shows the first (and only) output value.
                    TextView tv= findViewById(R.id.textView);
                    float[] data1=outputFeature0.getFloatArray();
                    tv.setText(String.valueOf(data1[0]));

                    // Releases model resources if no longer used.
                    model.close();

                }
                catch (Exception e)
                {
                    Toast.makeText(getApplicationContext(),"Issue...",Toast.LENGTH_LONG).show();
                }
            }
        });
    }
}
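One thing worth checking in the code above: `ByteBuffer.allocateDirect` returns a buffer in big-endian order, and the buffer given to `loadBuffer` is most likely passed to the interpreter as raw bytes, which TFLite reads in the device's native order (little-endian on typical Android hardware). The sketch below isolates only the input-building step with plain `java.nio`; the names `InputBuffers` and `singleFloatInput` are hypothetical, and it assumes a single `[1, 1]` FLOAT32 input as in `Model1`.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class InputBuffers {
    // Hypothetical helper: packs one FLOAT32 value into a direct buffer in the
    // device's native byte order, which is the layout TFLite reads.
    static ByteBuffer singleFloatInput(float value) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(Float.BYTES)
                .order(ByteOrder.nativeOrder()); // the call MainActivity omits
        buffer.putFloat(value);
        buffer.rewind(); // position back to 0 so all 4 bytes are read
        return buffer;
    }

    public static void main(String[] args) {
        ByteBuffer input = singleFloatInput(3f);
        System.out.println(input.order() + ", value = " + input.getFloat(0));
    }
}
```

In `MainActivity` the equivalent change would be `ByteBuffer.allocateDirect(1*4).order(ByteOrder.nativeOrder())` plus `import java.nio.ByteOrder;`. Alternatively, `inputFeature0.loadArray(new float[]{data});` skips the manual buffer and lets `TensorBuffer` handle the byte layout itself.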

Python outputs (input → output):

  • 1 → 1.4467…
  • 2 → 2.5395…
  • 2.1 → 2.6488…
  • 2.11 → 2.6597
  • 3 → 3.6323…
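Because the model is a single `Dense(units=1)` layer with the default linear activation, its prediction is $\hat{y} = w x + b$. Using the (truncated) Python outputs for inputs 1 and 2, the learned parameters can be estimated roughly as below; the resulting line also reproduces the other Python values, and it implies that an input close to 0 would give about 0.354.

```latex
\begin{aligned}
w &\approx \hat{y}(2) - \hat{y}(1) = 2.5395 - 1.4467 = 1.0928 \\
b &\approx \hat{y}(1) - w = 1.4467 - 1.0928 = 0.3539 \\
\hat{y}(x) &\approx 1.0928\,x + 0.3539
  \;\Rightarrow\; \hat{y}(3) \approx 3.6323,\quad \hat{y}(2.1) \approx 2.6488,\quad \hat{y}(0) \approx 0.3539
\end{aligned}
```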

Java outputs (input → output):

  • 1 → 0.3540…
  • 2 → 0.3540…
  • 2.1 → 2.967…E23
  • 2.11 → 0.39083…
  • 41 → 0.3540…
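These Java outputs are consistent with a byte-order mismatch. The sketch below writes each input the way `MainActivity` does (a default, big-endian `ByteBuffer`), reads the same four bytes back as little-endian, and feeds the result to a linear model. The constants `W = 1.0928` and `B = 0.3539` are an assumption: they are estimated from the Python outputs (1 → 1.4467…, 2 → 2.5395…), not read from the model file. Under that assumption, inputs 1, 2 and 41 become tiny denormal floats, so the prediction collapses to the bias, while 2.1 becomes a value near 2.7E23.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class ByteOrderDemo {
    // Assumed parameters of the Dense(units=1) layer, estimated from the
    // Python outputs (1 -> 1.4467, 2 -> 2.5395); not read from model1.tflite.
    static final float W = 1.0928f;
    static final float B = 0.3539f;

    // Writes `value` into a default (big-endian) buffer, as MainActivity does,
    // then reads the same 4 bytes as little-endian, as a little-endian device would.
    static float misread(float value) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(Float.BYTES); // BIG_ENDIAN by default
        buffer.putFloat(value);
        return buffer.order(ByteOrder.LITTLE_ENDIAN).getFloat(0);
    }

    // Prediction y = W * x + B when the model receives the misread input.
    static float predictFromMisreadInput(float value) {
        return W * misread(value) + B;
    }

    public static void main(String[] args) {
        for (float x : new float[] {1f, 2f, 2.1f, 2.11f, 41f}) {
            System.out.println(x + " -> seen as " + misread(x)
                    + " -> predicted " + predictFromMisreadInput(x));
        }
    }
}
```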

Why are the outputs from the Python and Java code so different for the same input?
Why does the Java version return an almost constant value (0.3540…) for most inputs?
Please help me fix this.
