On-device training for fashion mnist

Hello everyone,

I am trying to implement On-device training for fashion mnist model. But i am facing some run time issues.

I am really confused about this. Can anyone please assist me :pray:. I am not getting any idea how to implement this.


But i am facing some run time issues.

Could you please share stack trace?

Thank you!

Hi @chunduriv ,

Sorry for the late reply. here is the response i am getting when ever i click on train button

2023-05-11 12:31:39.615  8542-8542  MsyncFactory            com.application.mlmodeltesting       E  [static] ClassNotFoundException
                                                                                                    java.lang.ClassNotFoundException: com.mediatek.view.impl.MsyncFactoryImpl
                                                                                                    	at java.lang.Class.classForName(Native Method)
                                                                                                    	at java.lang.Class.forName(Class.java:454)
                                                                                                    	at java.lang.Class.forName(Class.java:379)
                                                                                                    	at com.mediatek.view.MsyncFactory.<clinit>(MsyncFactory.java:14)
                                                                                                    	at com.mediatek.view.MsyncFactory.getInstance(MsyncFactory.java:29)
                                                                                                    	at android.view.ViewRootImpl.<init>(ViewRootImpl.java:763)
                                                                                                    	at android.view.ViewRootImpl.<init>(ViewRootImpl.java:859)
                                                                                                    	at android.view.WindowManagerGlobal.addView(WindowManagerGlobal.java:393)
                                                                                                    	at android.view.WindowManagerImpl.addView(WindowManagerImpl.java:134)
                                                                                                    	at android.app.ActivityThread.handleResumeActivity(ActivityThread.java:5012)
                                                                                                    	at android.app.servertransaction.ResumeActivityItem.execute(ResumeActivityItem.java:54)
                                                                                                    	at android.app.servertransaction.ActivityTransactionItem.execute(ActivityTransactionItem.java:45)
                                                                                                    	at android.app.servertransaction.TransactionExecutor.executeLifecycleState(TransactionExecutor.java:176)

I recently changed the code for digit classification model

here is the code , i changed:

# Import the necessary libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize the pixel values
x_train = x_train / 255.0
x_test = x_test / 255.0

# Define a simple sequential model
def create_model():
  model = keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),

  return model

# Create a model instance
model = create_model()

# Train the model on the training data
model.fit(x_train, y_train, epochs=5)

# Evaluate the model on the test data
model.evaluate(x_test, y_test, verbose=2)

# Define a function to add signatures for on-device training
def add_signatures(model):
  # Get the input and output tensors of the model
  input_tensor = model.input
  output_tensor = model.output

  # Define a train function that takes the input tensor and updates the weights
  @tf.function(input_signature=[tf.TensorSpec(input_tensor.shape, input_tensor.dtype)])
  def train(input_data):
    with tf.GradientTape() as tape:
      predictions = model(input_data)
      loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_train[:len(input_data)], predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer = tf.keras.optimizers.Adam()
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Define an infer function that takes the input tensor and returns the output tensor
  @tf.function(input_signature=[tf.TensorSpec(input_tensor.shape, input_tensor.dtype)])
  def infer(input_data):
    return output_tensor

  # Define a save function that saves the weights to a file path
  @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
  def save(file_path):
    tf.io.write_file(file_path, tf.io.serialize_tensor(model.get_weights()))

  # Define a restore function that restores the weights from a file path
  @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
  def restore(file_path):
    model.set_weights(tf.io.parse_tensor(tf.io.read_file(file_path), out_type=tf.float32))

  # Return a dictionary of signatures
  signatures = {
      'train': train,
      'infer': infer,
      'save': save,
      'restore': restore,

  return signatures

# Convert the model to TensorFlow Lite format with signatures
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
converter.experimental_enable_resource_variables = True
converter._experimental_new_converter = True
signatures = add_signatures(model)
converter._experimental_signature_defs = signatures
tflite_model = converter.convert()

# Save the TensorFlow Lite model to a file
with open('digit_model.tflite', 'wb') as f:

# Download the TensorFlow Lite model file from Colab
from google.colab import files

The Activity class program in java class :

package com.application.mlmodeltesting;

import android.content.Intent;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import androidx.appcompat.app.AppCompatActivity;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

public class MainActivity extends AppCompatActivity {
    private Interpreter interpreter;
    private ImageView imageView;
    private TextView predictionTextView;
    private TextView accuracyTextView;
    private Button chooseImageButton;
    private Button classifyButton;
    private Bitmap imagebitmap;

    private ByteBuffer inputBuffer;

    private float[][] outputBuffer;

    private static final int NUM_CLASSES = 10;

    private static final int PICK_IMAGE = 1;
    private static final int IMAGE_SIZE = 28;

    private static final String MODEL_FILE = "digit_model.tflite";
    private static final String TRAIN_SIGNATURE = "train";
    private static final String INFER_SIGNATURE = "infer";
    private static final String SAVE_SIGNATURE = "save";
    private static final String RESTORE_SIGNATURE = "restore";

    protected void onCreate(Bundle savedInstanceState) {

        // Initialize TensorFlow Lite model
        try {
            interpreter = new Interpreter(loadModelFile());
            inputBuffer = ByteBuffer.allocateDirect(IMAGE_SIZE * IMAGE_SIZE * 4);
            outputBuffer = new float[1][NUM_CLASSES];
        } catch (IOException e) {

        // Get references to views
        imageView = findViewById(R.id.image_view);
        predictionTextView = findViewById(R.id.prediction_text_view);
        accuracyTextView = findViewById(R.id.accuracy_text_view);
        chooseImageButton = findViewById(R.id.choose_image_button);
        classifyButton = findViewById(R.id.classify_button);

        // Set up choose image button
        chooseImageButton.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                // Open image picker
                // ...
                Intent intent = new Intent();
                startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE);

        classifyButton.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                // Get the image from the imageView and convert it to a byte buffer

                // Create a map for the input tensor
                Map<String, Object> inputMap = new HashMap<>();
                inputMap.put("input", inputBuffer);

                // Create a map for the output tensor
                Map<String, Object> outputMap = new HashMap<>();
                outputMap.put("output", outputBuffer);

                // Run inference using the infer signature
                interpreter.runSignature(inputMap, outputMap, INFER_SIGNATURE);

                // Restore the model weights from the file system
                interpreter.runSignature(null, null, RESTORE_SIGNATURE);

                // Find the index of the maximum value in the output buffer
                int maxIndex = 0;
                float maxValue = 0f;
                for (int i = 0; i < NUM_CLASSES; i++) {
                    if (outputBuffer[0][i] > maxValue) {
                        maxIndex = i;
                        maxValue = outputBuffer[0][i];

                // Show the predicted digit and the confidence score
                predictionTextView.setText("Predicted digit: " + maxIndex + "\nConfidence: " + maxValue);


        // Set up classify button
//        classifyButton.setOnClickListener(new View.OnClickListener() {
//            @Override
//            public void onClick(View v) {
//                if (imagebitmap != null) {
//                    // Preprocess the image
//                    float[] input = getPreprocessedImage(imagebitmap);
//                    // Classify the image
//                    float[][] output = new float[1][10];
//                    Map<String, Object> inputs = new HashMap<>();
//                    inputs.put("x", input);
//                    inputs.put("y", output);
//                    Map<String, Object> outputs = new HashMap<>();
//                    FloatBuffer loss = FloatBuffer.allocate(1);
//                    outputs.put("loss", loss);
////                    tflite.runSignature(inputs, outputs, "train");
//                    tflite.run(input, output);
//                    // Find the class with the highest confidence
//                    int classIndex = 0;
//                    float maxConfidence = output[0][0];
//                    for (int i = 1; i < 10; i++) {
//                        if (output[0][i] > maxConfidence) {
//                            classIndex = i;
//                            maxConfidence = output[0][i];
//                        }
//                    }
//                    // Display the prediction
//                    predictionTextView.setText(String.valueOf(classIndex));
//                    accuracyTextView.setText(String.valueOf(maxConfidence));
//                }
//            }
//        });

    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);

        if (requestCode == PICK_IMAGE && resultCode == RESULT_OK && data != null && data.getData() != null) {
            // Get the image URI
            Uri imageUri = data.getData();

            try {
                // Convert the image URI to a Bitmap
                imagebitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), imageUri);

                // Display the image

                imagebitmap = getPreprocessedImage(imagebitmap);
            } catch (IOException e) {

    private Bitmap getPreprocessedImage(Bitmap image) {
        // Resize the image
        Bitmap resizedImage = Bitmap.createScaledBitmap(image, 28, 28, true);

        // Convert the image to a float array
        int width = resizedImage.getWidth();
        int height = resizedImage.getHeight();
        int[] pixels = new int[width * height];
        resizedImage.getPixels(pixels, 0, width, 0, 0, width, height);

        float[] imageData = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            imageData[i] = (pixels[i] & 0xff) / 255.0f;

        return resizedImage;

    private MappedByteBuffer loadModelFile() throws IOException {
        // Open the model file from the assets folder
        AssetFileDescriptor fileDescriptor = getAssets().openFd(MODEL_FILE);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, fileDescriptor.getStartOffset(), fileDescriptor.getDeclaredLength());

    private float[] inferImage(Bitmap bitmap) {
        // Pass the preprocessed image to the TensorFlow Lite model for inference
        TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
        ByteBuffer byteBuffer = tensorImage.getBuffer();
        TensorBuffer inputBuffer = TensorBuffer.createFixedSize(new int[]{1, IMAGE_SIZE, IMAGE_SIZE, 1}, DataType.FLOAT32);

        TensorBuffer outputBuffer = TensorBuffer.createFixedSize(new int[]{1, 10}, DataType.FLOAT32);
        interpreter.run(inputBuffer.getBuffer(), outputBuffer.getBuffer());

        return outputBuffer.getFloatArray();


The error i am getting is :

2023-05-11 12:32:03.085  8542-8542  AndroidRuntime          com.application.mlmodeltesting       E  FATAL EXCEPTION: main
                                                                                                    Process: com.application.mlmodeltesting, PID: 8542
                                                                                                    java.lang.IllegalArgumentException: Input error: Signature infer not found.
                                                                                                    	at org.tensorflow.lite.NativeSignatureRunnerWrapper.<init>(NativeSignatureRunnerWrapper.java:28)
                                                                                                    	at org.tensorflow.lite.NativeInterpreterWrapper.getSignatureRunnerWrapper(NativeInterpreterWrapper.java:543)
                                                                                                    	at org.tensorflow.lite.NativeInterpreterWrapper.runSignature(NativeInterpreterWrapper.java:181)
                                                                                                    	at org.tensorflow.lite.Interpreter.runSignature(Interpreter.java:253)
                                                                                                    	at com.application.mlmodeltesting.MainActivity$2.onClick(MainActivity.java:110)
                                                                                                    	at android.view.View.performClick(View.java:7751)
                                                                                                    	at com.google.android.material.button.MaterialButton.performClick(MaterialButton.java:1219)
                                                                                                    	at android.view.View.performClickInternal(View.java:7724)
                                                                                                    	at android.view.View.access$3700(View.java:858)
                                                                                                    	at android.view.View$PerformClick.run(View.java:29336)
                                                                                                    	at android.os.Handler.handleCallback(Handler.java:938)
                                                                                                    	at android.os.Handler.dispatchMessage(Handler.java:99)
                                                                                                    	at android.os.Looper.loopOnce(Looper.java:210)
                                                                                                    	at android.os.Looper.loop(Looper.java:299)
                                                                                                    	at android.app.ActivityThread.main(ActivityThread.java:8280)
                                                                                                    	at java.lang.reflect.Method.invoke(Native Method)
                                                                                                    	at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:576)
                                                                                                    	at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1073)

@chunduriv can you help me figure out , what is the mistake ? :slight_smile: