Passing a dictionary as input signature for tf.function

Hello everyone,

I am currently trying to create a tf.function for my TensorFlow Lite model that takes a dictionary of weights for the different tensors and loads them into the TFLite interpreter. I can already call the function by using a workaround in which I hardcode the input signature for each dictionary value, but I would like to know whether there is a better solution.

The dictionary I want to pass looks something like this (in my code the dictionary is generated; the literal below only demonstrates its structure):

weights_dict = {
"conv1d/bias:0": [0., 0., 0., 0., 0.],
"conv1d/kernel:0": [[[-0.23031473,  0.25823927, -0.15676963, -0.27285957,  0.36472344]],
                    [[-0.15187979, -0.00110114,  0.05242634,  0.2745613,  -0.20043385]],
                    [[ 0.03774023, -0.45265818,  0.24690723,  0.23345459,  0.15752888]],
                    [[-0.3444885,  -0.07790518, -0.11909294, -0.34963953,  0.14239442]]],
"dense/bias:0": [0., 0., 0., 0., 0., 0., 0., 0.],
"dense/kernel:0": [[ 0.10110682, -0.50597924,  0.3949796,   0.20653826, -0.26708043,  0.2091226, -0.5352592,  -0.6049661],
                    [-0.4578067,   0.29674542, -0.27303556, -0.568668,    0.4862703,   0.5257646, -0.28185177,  0.43704462],
                    [-0.45960283, -0.44893104, -0.386113,   -0.18281126, -0.44102007, -0.21802643,  0.60579884,  0.23984993],
                    [ 0.45038152,  0.5851245,   0.40392494, 0.28304034,  0.6369395,  -0.06007874,   -0.5181305,  -0.6644601],
                    [-0.42245385,  0.43564594,  0.45273066,  0.5162871,   0.12010175, -0.4153296, -0.00371552,  0.17521149]],
"dense_1/bias:0": [0.],
"dense_1/kernel:0": [[0.18899703], 
                     [0.54743993], 
                     [0.57986426],
                     [-0.31321746],
                     [0.37268388],
                     [0.18092150],
                     [-0.14566028],
                     [-0.26471186]]
}

This dictionary is passed into this tf.function:

@tf.function(input_signature=[signature_dict])
def set_weights(self, weights):
    tf.print("im in")
    tensor_names = [weight.name for weight in self.model.weights]
    for i, tensor in enumerate(self.model.weights):
        tensor.assign(weights[tensor_names[i]])
    return tensor_names

The workaround I mentioned is the signature_dict that I pass as the input_signature. It hardcodes the shape of every tensor and looks like this:

signature_dict = { "conv1d/bias:0": tf.TensorSpec(shape=[5], dtype=tf.float32),
                  "conv1d/kernel:0": tf.TensorSpec(shape=[4, 1, 5], dtype=tf.float32),
                  "dense/bias:0": tf.TensorSpec(shape=[8], dtype=tf.float32),
                  "dense/kernel:0": tf.TensorSpec(shape=[5, 8], dtype=tf.float32),
                  "dense_1/bias:0": tf.TensorSpec(shape=[1], dtype=tf.float32),
                  "dense_1/kernel:0": tf.TensorSpec(shape=[8, 1], dtype=tf.float32) }
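As a sketch of how the hardcoded shapes could be avoided: input_signature accepts nested structures, so the dict of TensorSpecs can be built from the variables that are being assigned. `WeightSetter` and `named_vars` are hypothetical names, and the variables here are plain tf.Variable objects standing in for a Keras model's weights:

```python
import numpy as np
import tensorflow as tf


class WeightSetter(tf.Module):
    """Hypothetical stand-in for the poster's class, not code from the thread.

    The set_weights input signature is derived from the variables it owns,
    so no TensorSpec has to be written by hand.
    """

    def __init__(self, named_vars):
        super().__init__()
        # Hypothetical {name: tf.Variable} mapping; with a Keras model it
        # would be built from model.weights instead.
        self.named_vars = dict(named_vars)
        signature = {
            name: tf.TensorSpec(shape=var.shape, dtype=var.dtype)
            for name, var in self.named_vars.items()
        }
        # input_signature has one entry per positional argument; here the
        # single argument is a whole dict of TensorSpecs.
        self.set_weights = tf.function(self._set_weights, input_signature=[signature])

    def _set_weights(self, weights):
        # While tracing, `weights` is a dict of Tensors keyed like `signature`.
        for name, var in self.named_vars.items():
            var.assign(weights[name])
```

With a tf.keras model, the mapping could be built from model.weights (for instance keyed by weight.name, as in set_weights above; the exact name format depends on the Keras version). This only removes the hardcoded shapes: after TFLite conversion the dict argument is presumably still flattened into weights, weights_1, and so on, as in the call shown next.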

In order to invoke this function I need to do this:

set_weights = interpreter.get_signature_runner("set_weights")

weights = weights_dict["conv1d/bias:0"]
weights_1 = weights_dict["conv1d/kernel:0"]
weights_2 = weights_dict["dense/bias:0"]
weights_3 = weights_dict["dense/kernel:0"]
weights_4 = weights_dict["dense_1/bias:0"]
weights_5 = weights_dict["dense_1/kernel:0"]

output = set_weights(weights=weights, weights_1=weights_1, weights_2=weights_2, weights_3=weights_3, weights_4=weights_4, weights_5=weights_5)

As you can see, I have to use an awkward workaround: I manually create the signature TensorSpecs and extract each weight tensor from the dictionary before passing them as six separate keyword arguments.
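The manual extraction step could be automated with a small helper. This is a sketch under one assumption taken from the call above: the runner's flattened inputs are named weights, weights_1, weights_2, … in sorted-key order, which is the order tf.nest uses when it flattens a dict. `to_runner_kwargs` is a hypothetical helper; `interpreter.get_signature_list()` shows the actual input names and is worth checking before relying on this mapping.

```python
import numpy as np


def to_runner_kwargs(weights_dict, arg_name="weights"):
    """Hypothetical helper: map {tensor_name: values} onto the flattened
    keyword arguments of a TFLite signature runner.

    Assumes the flattened inputs are named arg_name, arg_name_1, arg_name_2, ...
    in sorted-key order, matching how tf.nest flattens a dict.
    """
    kwargs = {}
    for i, key in enumerate(sorted(weights_dict)):
        input_name = arg_name if i == 0 else f"{arg_name}_{i}"
        # The signature uses tf.float32, so convert to float32 NumPy arrays.
        kwargs[input_name] = np.asarray(weights_dict[key], dtype=np.float32)
    return kwargs
```

The six assignments would then collapse into `output = set_weights(**to_runner_kwargs(weights_dict))`, and the helper keeps working when layers are added, as long as the naming assumption holds.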

I would really like to simplify this, because it is clumsy and I have to change everything whenever I change the model layers. Ideally, I want to pass the whole dictionary to the set_weights function and let it do the rest, like this:

output = set_weights(weights=weights_dict)

Does anybody know if there is a way to pass the dictionary as a whole into the tf.function and how? Thanks in advance!

Hi @Niklas_Kiefer,

I have modified the code in your set_weights(self, weights) method. Please take a look and let me know whether it works for you. I haven't tested the code below in any environment; I just followed the TensorFlow documentation and modified it accordingly.

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def set_weights(self, weights):
    tf.print("im in")
    for i, weight_name in enumerate(weights.keys()):
        weight = weights[weight_name]
        weight_shape = tf.TensorShape(weight.shape)
        tensor = self.model.get_layer(weight_name.split('/')[0]).get_weights()[i]
        tensor.assign(tf.reshape(weight, weight_shape))
    return list(weights.keys())

And then you can call the below code:

set_weights = interpreter.get_signature_runner("set_weights")
output = set_weights(weights_dict)

For details, please see the tf.TensorShape page in the TensorFlow API reference and the set_weights documentation.

Many thanks.

Hi @Laxma_Reddy_Patlolla,

First of all, thank you for helping with my problem!

I tried applying your changes to my code, but unfortunately the following exception is raised:

Traceback (most recent call last):
  File "/mnt/c/Users/nikla/Desktop/FedLF/server.py", line 11, in <module>
    tf.saved_model.save(
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/saved_model/save.py", line 1231, in save
    save_and_return_nodes(obj, export_dir, signatures, options)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/saved_model/save.py", line 1267, in save_and_return_nodes
    _build_meta_graph(obj, signatures, options, meta_graph_def))
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/saved_model/save.py", line 1440, in _build_meta_graph
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/saved_model/save.py", line 1388, in _build_meta_graph_impl
    signature_serialization.validate_augmented_graph_view(augmented_graph_view)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 304, in validate_augmented_graph_view
    for name, dep in augmented_graph_view.list_children(
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/saved_model/save.py", line 177, in list_children
    for name, child in super(_AugmentedGraphView, self).list_children(
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/checkpoint/graph_view.py", line 75, in list_children
    for name, ref in super(ObjectGraphView,
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/checkpoint/trackable_view.py", line 84, in children
    for name, ref in obj._trackable_children(save_type, **kwargs).items():
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/trackable/autotrackable.py", line 115, in _trackable_children
    fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1141, in _list_all_concrete_functions_for_serialization
    concrete_functions = self._list_all_concrete_functions()
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1123, in _list_all_concrete_functions
    self.get_concrete_function()
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1215, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1195, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 749, in _initialize
    self._variable_creation_fn    # pylint: disable=protected-access
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 162, in _get_concrete_function_internal_garbage_collected
    concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function
    return self._maybe_define_function(args, kwargs)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function
    concrete_function = self._create_concrete_function(args, kwargs)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function
    func_graph_module.func_graph_from_py_func(
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 645, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 445, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1269, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1258, in autograph_handler
    return autograph.converted_call(
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/tmp/__autograph_generated_fileid_be7hf.py", line 29, in tf__set_weights
    ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.converted_call(ag__.ld(weights).keys, (), None, fscope),), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': '(i, weight_name)'})
  File "/home/niklas/miniconda3/envs/tenv/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 444, in __getattr__
    self.__getattribute__(name)
AttributeError: in user code:

    File "/mnt/c/Users/nikla/Desktop/FedLF/fl_model.py", line 78, in set_weights  *
        for i, weight_name in enumerate(weights.keys()):

    AttributeError: 'Tensor' object has no attribute 'keys'

It seems that the weights parameter of the function is treated as a Tensor object instead of a dictionary (the input_signature now declares a single tf.TensorSpec), so it has no keys() attribute.
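The type of `weights` during tracing follows the structure of input_signature. A minimal sketch that records what each traced function receives; `single_spec`, `dict_spec` and the two-key signature are made-up examples, not code from the thread:

```python
import tensorflow as tf

# Records what `weights` looks like while each function is being traced.
seen = {}


# One TensorSpec in input_signature: `weights` arrives as a single Tensor.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def single_spec(weights):
    seen["single_is_tensor"] = isinstance(weights, tf.Tensor)
    seen["single_is_dict"] = isinstance(weights, dict)
    return tf.reduce_sum(weights)


# A dict of TensorSpecs: `weights` arrives as a dict of Tensors with these keys.
dict_signature = {
    "conv1d/bias:0": tf.TensorSpec(shape=[5], dtype=tf.float32),
    "dense_1/bias:0": tf.TensorSpec(shape=[1], dtype=tf.float32),
}


@tf.function(input_signature=[dict_signature])
def dict_spec(weights):
    seen["dict_is_dict"] = isinstance(weights, dict)
    seen["dict_keys"] = sorted(weights.keys())
    return tf.add_n([tf.reduce_sum(t) for t in weights.values()])


single_spec(tf.ones([3]))
total = dict_spec({"conv1d/bias:0": tf.ones([5]), "dense_1/bias:0": tf.ones([1])})
```

In other words, keys() is only available when the signature itself is a dict of TensorSpecs, as in the original signature_dict workaround.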