Loading model weights in TensorFlow Lite directly from a signature

Hello everyone,

I am currently trying to create a federated learning application for an embedded system, and to do that I need to load model weights directly from a dictionary instead of from a checkpoint file. I am using a TFLite model similar to the one shown in the On-Device Training tutorial, and I want to add an additional signature for loading the weights from a dictionary. Is there a way to do this, or does anybody know a different way of achieving it?


Hi @Niklas_Kiefer,

Here are my inputs:

Define a new signature that takes a dictionary containing the model weights as input, as follows:

def set_weights_from_dict(weights_dict):
    # Parse the dictionary and set the weights accordingly
    pass
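The signatures of a .tflite model are the tf.functions exported when the model is converted, so a minimal sketch of such a weight-loading signature defines it on the tf.Module before conversion, following the save/convert pattern of the On-Device Training tutorial. It assumes TensorFlow 2.7 or later; TinyModel, set_weights and the weight names w and b are hypothetical. Each dictionary entry becomes one named input of the signature, so the dictionary can be passed to the signature runner as keyword arguments.

```python
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical linear model with an extra weight-loading signature."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.zeros([3, 2]), name="w")
        self.b = tf.Variable(tf.zeros([2]), name="b")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def infer(self, x):
        return {"output": tf.matmul(x, self.w) + self.b}

    # One named input per weight: the argument names become the input keys.
    @tf.function(input_signature=[
        tf.TensorSpec([3, 2], tf.float32),
        tf.TensorSpec([2], tf.float32),
    ])
    def set_weights(self, w, b):
        self.w.assign(w)
        self.b.assign(b)
        # A signature needs outputs; return the newly assigned values.
        return {"new_w": self.w.read_value(), "new_b": self.b.read_value()}


model = TinyModel()
with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(
        model,
        saved_model_dir,
        signatures={
            "infer": model.infer.get_concrete_function(),
            "set_weights": model.set_weights.get_concrete_function(),
        },
    )
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # Keep the tf.Variables as mutable resource variables in the .tflite model.
    converter.experimental_enable_resource_variables = True
    tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
set_weights = interpreter.get_signature_runner("set_weights")
infer = interpreter.get_signature_runner("infer")

# Weights as they might arrive from a federated-learning server.
weights = {
    "w": np.ones([3, 2], dtype=np.float32),
    "b": np.full([2], 0.5, dtype=np.float32),
}
set_weights(**weights)

result = infer(x=np.ones([1, 3], dtype=np.float32))["output"]
print(result)  # every entry is 1 + 1 + 1 + 0.5 = 3.5
```

Both signature runners come from the same interpreter, so they share the resource variables: weights written through set_weights are the ones infer uses afterwards.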

Load the TFLite model using tf.lite.Interpreter() and allocate its tensors:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

Get the input and output tensor details:

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

Add the additional signature to the TFLite model using the add_signature() method of the interpreter object.

    inputs=input_details + [tf.TensorSpec(shape=[], dtype=tf.string)],

You can set the weights from the dictionary using the invoke() method of the interpreter object.

import pickle

weights_dict = {'weight_1': ..., 'weight_2': ..., ...}
weights_dict_bytes = pickle.dumps(weights_dict)
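TFLite ops cannot parse pickle data inside the graph, so if a single input tensor is preferred over one input per weight, one alternative (a swapped-in technique, not a TFLite API) is to flatten the dictionary into one float32 vector in a fixed key order. The numpy sketch below shows the packing and the matching unpacking; pack_weights, unpack_weights and the shapes in WEIGHT_SHAPES are hypothetical.

```python
import numpy as np

# Fixed key order and shapes shared by the sender and the model (hypothetical).
WEIGHT_SHAPES = {"w": (3, 2), "b": (2,)}


def pack_weights(weights_dict, shapes=WEIGHT_SHAPES):
    """Flatten a dict of arrays into one float32 vector, in the order of `shapes`."""
    parts = []
    for name, shape in shapes.items():
        value = np.asarray(weights_dict[name], dtype=np.float32)
        if value.shape != tuple(shape):
            raise ValueError(f"{name}: expected shape {shape}, got {value.shape}")
        parts.append(value.ravel())
    return np.concatenate(parts)


def unpack_weights(flat, shapes=WEIGHT_SHAPES):
    """Inverse of pack_weights: split the vector and reshape each piece."""
    result, offset = {}, 0
    for name, shape in shapes.items():
        size = int(np.prod(shape))
        result[name] = flat[offset:offset + size].reshape(shape)
        offset += size
    if offset != flat.size:
        raise ValueError(f"expected {offset} values, got {flat.size}")
    return result


weights = {
    "w": np.arange(6, dtype=np.float32).reshape(3, 2),
    "b": np.array([0.5, -0.5]),
}
flat = pack_weights(weights)
print(flat.shape)  # (8,)
restored = unpack_weights(flat)
```

Inside the model, the same split could be done on a signature input of shape [8] with tf.split(flat, [6, 2]) followed by tf.reshape, before assigning each piece to its variable.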


For a more detailed explanation, please refer to the Signatures in TensorFlow Lite and TensorFlow Lite inference (inference_signature) documentation.

Please let me know if it helps you.


Hi @Laxma_Reddy_Patlolla,

First of all, thank you for the quick reply; I appreciate you taking the time. Unfortunately, the interpreter.add_signature(…) method does not seem to exist. I tried it in code and searched the internet as well as the official TensorFlow API documentation, but I cannot find an add_signature method. For clarification, I am using TensorFlow version 2.8.2.

Could it be that the method is only available in later versions of TensorFlow?
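A quick way to check which interpreter methods a given installation actually provides, assuming a standard pip install of TensorFlow; the method names are add_signature from this thread plus get_signature_list and get_signature_runner, which are documented methods of tf.lite.Interpreter.

```python
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Report which of these methods exist on the interpreter class in this install.
methods = ("add_signature", "get_signature_list", "get_signature_runner")
available = {name: hasattr(tf.lite.Interpreter, name) for name in methods}
for name, present in available.items():
    print(f"{name}: {present}")
```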