Hello everyone,
I am currently trying to create a federated learning application for an embedded system, and to do that I need to load model weights directly from a dictionary instead of from a checkpoint file. I am using a TFLite model similar to the one shown in the On-Device Training tutorial, and I want to add an additional signature for loading the weights from a dictionary. Is there a way to do this, or does anybody know a different way of achieving this?
Hi @Niklas_Kiefer ,
Here are my inputs:
Define a new signature that takes a dictionary containing the model weights as input, as follows:
def set_weights_from_dict(weights_dict):
    # Parse the dictionary and set the weights accordingly
    ...
Load the TFLite model using tf.lite.Interpreter():
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
Add the additional signature to the TFLite model using the add_signature() method of the interpreter object:
interpreter.add_signature(
    inputs=input_details + [tf.TensorSpec(shape=[], dtype=tf.string)],
    outputs=output_details,
    name='set_weights_from_dict',
    saved_model_dir=None,
    concrete_function=set_weights_from_dict
)
You can then set the weights from the dictionary using the invoke() method of the interpreter object:
import pickle

weights_dict = {'weight_1': ..., 'weight_2': ..., ...}
weights_dict_bytes = pickle.dumps(weights_dict)
interpreter.invoke(inputs=[weights_dict_bytes])
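Which signature-related methods `tf.lite.Interpreter` actually provides can be checked against the installed TensorFlow. A minimal check, assuming a standard `tensorflow` pip install; it lists the interpreter's public methods whose names mention signatures and tests whether `add_signature` is present:

```python
import tensorflow as tf

# Public methods of the Python TFLite interpreter whose names mention signatures.
signature_methods = sorted(
    name for name in dir(tf.lite.Interpreter)
    if "signature" in name and not name.startswith("_")
)
# True only if the installed interpreter class defines add_signature.
has_add_signature = hasattr(tf.lite.Interpreter, "add_signature")

print(tf.__version__)
print(signature_methods)
print(has_add_signature)
```

In TensorFlow 2.8 the list is expected to contain `get_signature_list` and `get_signature_runner`, which read and call signatures already stored in a converted model rather than adding new ones.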
Please also see the Signatures in TensorFlow Lite and TensorFlow Lite inference (inference_signature) documentation for a more detailed explanation.
Please let me know if it helps you.
Thanks.
Hi @Laxma_Reddy_Patlolla,
First of all, thank you for the quick reply. I appreciate you taking the time. Unfortunately, it seems that the interpreter.add_signature(…) method does not exist. I tried it in code and searched online as well as in the official TensorFlow API documentation, but I cannot find an add_signature method. For reference, I am using TensorFlow version 2.8.2.
Could it be that the function is available in later versions of TensorFlow?
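For reference, a sketch of one approach that uses only documented calls available in TensorFlow 2.8: signatures in a TFLite model are fixed when the model is converted, so the weight-loading function is declared as an extra `tf.function` on the `tf.Module` before conversion (much like the save/restore signatures in the On-Device Training tutorial) and later called through `Interpreter.get_signature_runner()`. The `TinyModel` class, its two variables and the `set_weights` signature name are hypothetical stand-ins for the tutorial model. Because a TFLite graph cannot unpickle a Python dictionary, each weight is passed as its own named tensor, and the dictionary keys must match those tensor names.

```python
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the tutorial model: output = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.zeros([3, 2]))
        self.b = tf.Variable(tf.zeros([2]))

    @tf.function(input_signature=[tf.TensorSpec([1, 3], tf.float32, name="x")])
    def infer(self, x):
        return {"output": tf.matmul(x, self.w) + self.b}

    # Hypothetical extra signature: one input tensor per weight, no checkpoint file.
    @tf.function(input_signature=[
        tf.TensorSpec([3, 2], tf.float32, name="w"),
        tf.TensorSpec([2], tf.float32, name="b"),
    ])
    def set_weights(self, w, b):
        self.w.assign(w)
        self.b.assign(b)
        # Echo the new values back as the signature's outputs.
        return {"w": tf.identity(self.w), "b": tf.identity(self.b)}


# Signatures are fixed here, at export/conversion time.
model = TinyModel()
saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(
    model,
    saved_model_dir,
    signatures={
        "infer": model.infer.get_concrete_function(),
        "set_weights": model.set_weights.get_concrete_function(),
    },
)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

# Both signature runners share the interpreter's variables.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
set_weights = interpreter.get_signature_runner("set_weights")
infer = interpreter.get_signature_runner("infer")

# Keys must match the TensorSpec names of the set_weights signature.
weights_dict = {
    "w": np.ones([3, 2], dtype=np.float32),
    "b": np.array([0.5, -0.5], dtype=np.float32),
}
set_weights(**weights_dict)

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
result = infer(x=x)["output"]
print(result)
```

Passing one tensor per weight keeps every input shape static and avoids decoding serialized data inside the TFLite graph; the trade-off is that the keys and shapes in the weights dictionary must match the signature exactly.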