Fail to extract concrete function from tf.saved_model with multiple input nodes

Hello,

I am currently trying to save model signatures with a static batch size in order to compile the network “ahead-of-time”. To do this, I create a concrete_function from a static TensorSpec. I’m doing this with TensorFlow 2.8.0.
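For context, here is a minimal sketch of what “a concrete function with a static TensorSpec” means, using a plain tf.function instead of the Keras model from this thread. The function `double` is a hypothetical toy stand-in; the point is only that tracing with a fully specified shape fixes the batch dimension in the traced graph.

```python
import tensorflow as tf


@tf.function
def double(x):
    # Toy computation standing in for a real network.
    return x * 2.0


# Trace once with a fully static shape: the batch size is fixed to 1.
static_spec = tf.TensorSpec(shape=(1, 2), dtype=tf.float32)
concrete = double.get_concrete_function(static_spec)

# The traced graph's input placeholder carries the static shape.
input_shape = tuple(concrete.inputs[0].shape)
result = concrete(tf.ones((1, 2)))
print(input_shape, result.numpy())
```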

This conversion works well for models with a single input node and for Keras models with multiple input nodes. However, when I try to extract the concrete_function from a model that was saved with tf.saved_model.save and restored with tf.saved_model.load, I get the following error, which I did not expect:

ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):
    * [<tf.Tensor 'inputs:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'inputs_1:0' shape=(1, 3) dtype=float32>, <tf.Tensor 'inputs_2:0' shape=(1, 10) dtype=float32>]
    * False
    * None
  Keyword arguments: {}

 Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (3 total):
    * (TensorSpec(shape=(None, 2), dtype=tf.float32, name='first'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='second'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='third'))
    * False
    * None
  Keyword arguments: {}

Option 2:
  Positional arguments (3 total):
    * (TensorSpec(shape=(None, 2), dtype=tf.float32, name='inputs/0'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='inputs/1'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='inputs/2'))
    * False
    * None
  Keyword arguments: {}

Option 3:
  Positional arguments (3 total):
    * (TensorSpec(shape=(None, 2), dtype=tf.float32, name='inputs/0'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='inputs/1'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='inputs/2'))
    * True
    * None
  Keyword arguments: {}

Option 4:
  Positional arguments (3 total):
    * (TensorSpec(shape=(None, 2), dtype=tf.float32, name='first'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='second'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='third'))
    * True
    * None

So it seems that no matching function could be found, which suggests I am doing something wrong. Additionally, I do not understand why my input TensorSpecs lose their name information when passed to get_concrete_function.
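Comparing the “Got” block with the four options: the shapes (1, n) are compatible with (None, n), and training=False, mask=None match Options 1 and 2. The visible difference is that the inputs arrive as a list `[...]`, while every saved option lists a tuple `(...)`. Restored SavedModel functions appear to compare the nested structure strictly, including the container type. The sketch below uses tf.nest (not the SavedModel internals themselves) to illustrate that a list and a tuple with compatible leaves still count as different structures under strict type checking:

```python
import tensorflow as tf

# Same number of leaves; only the container type differs (list vs tuple).
spec_list = [tf.TensorSpec((1, 2), tf.float32), tf.TensorSpec((1, 3), tf.float32)]
spec_tuple = (tf.TensorSpec((None, 2), tf.float32), tf.TensorSpec((None, 3), tf.float32))

# Strict comparison treats list and tuple as different sequence types.
try:
    tf.nest.assert_same_structure(spec_list, spec_tuple, check_types=True)
    strict_match = True
except TypeError:
    strict_match = False

# The leaves themselves are compatible: a (None, 2) spec accepts a (1, 2) spec.
shapes_compatible = all(
    expected.is_compatible_with(given)
    for expected, given in zip(spec_tuple, spec_list)
)
print(strict_match, shapes_compatible)
```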

My question now is: How do I correctly extract the concrete function for my saved_model?
I expected my code to work, since the same approach works for Keras models.

To reproduce this behavior, please use the following code example:

import tensorflow as tf

def create_test_model():
    # model with 3 input nodes (the third is unused) and 1 output node, with a dynamic batch size
    x1 = tf.keras.Input(shape=(2,), name="first")
    x2 = tf.keras.Input(shape=(3,), name="second")
    x3 = tf.keras.Input(shape=(10,), name="third")

    x = tf.concat([x1, x2], axis=1)
    a1 = tf.keras.layers.Dense(10, activation="elu")(x)
    y = tf.keras.layers.Dense(5, activation="softmax")(a1)
    model = tf.keras.Model(inputs=(x1, x2, x3), outputs=y)
    return model


def static_concrete_function(model, batch_size: int):
    static_tensorspec = [tf.TensorSpec(shape=(batch_size, 2), dtype=tf.float32, name='first'),
                         tf.TensorSpec(shape=(batch_size, 3), dtype=tf.float32, name='second'),
                         tf.TensorSpec(shape=(batch_size, 10), dtype=tf.float32, name='third')]

    # get the concrete function for the signature: static_tensorspec
    new_signature = tf.function(
        model.__call__).get_concrete_function(inputs=static_tensorspec, training=False, mask=None)
    return new_signature

def main():
    # create and save model
    model = create_test_model()
    path_tf = "./tf_model"
    tf.saved_model.save(model, path_tf)

    path_keras = "./keras_model"
    model.save(path_keras, overwrite=True, include_optimizer=False)

    # load model
    keras_model = tf.keras.models.load_model(path_keras)
    tf_model = tf.saved_model.load(path_tf)

    # extract concrete function with static batch size
    batch_size = 1
    # works for keras model
    keras_concrete = static_concrete_function(keras_model, batch_size)
    print("*" * 50)
    print(f"Keras Models: Input Signature\n{keras_concrete.structured_input_signature}")
    print("*" * 50)

    # fails to find matching signature for tensorflow saved model
    tf_concrete = static_concrete_function(tf_model, batch_size)


if __name__ == "__main__":
    main()
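One plausible fix, based on the error output, is to build static_tensorspec as a tuple instead of a list, so that its nesting matches the tuples listed in the expected options (the model was also built with `inputs=(x1, x2, x3)`, a tuple). The sketch below keeps the reproduction and changes only that; it assumes, as in the original code, that the restored `__call__` is what gets wrapped in tf.function, and it writes to a temporary directory instead of ./tf_model.

```python
import tempfile

import tensorflow as tf


def create_test_model():
    # Same toy model as above: three inputs (the third is unused), one output.
    x1 = tf.keras.Input(shape=(2,), name="first")
    x2 = tf.keras.Input(shape=(3,), name="second")
    x3 = tf.keras.Input(shape=(10,), name="third")
    x = tf.concat([x1, x2], axis=1)
    a1 = tf.keras.layers.Dense(10, activation="elu")(x)
    y = tf.keras.layers.Dense(5, activation="softmax")(a1)
    return tf.keras.Model(inputs=(x1, x2, x3), outputs=y)


def static_concrete_function(model, batch_size: int):
    # A tuple, not a list: this mirrors the nesting of the saved __call__ traces.
    static_tensorspec = (
        tf.TensorSpec(shape=(batch_size, 2), dtype=tf.float32, name="first"),
        tf.TensorSpec(shape=(batch_size, 3), dtype=tf.float32, name="second"),
        tf.TensorSpec(shape=(batch_size, 10), dtype=tf.float32, name="third"),
    )
    return tf.function(model.__call__).get_concrete_function(
        inputs=static_tensorspec, training=False, mask=None)


path_tf = tempfile.mkdtemp()
tf.saved_model.save(create_test_model(), path_tf)
tf_model = tf.saved_model.load(path_tf)

# Raises the "Could not find matching concrete function" error if nothing matches.
tf_concrete = static_concrete_function(tf_model, batch_size=1)

# Collect the TensorSpecs from the structured signature (skipping False/None).
input_specs = [s for s in tf.nest.flatten(tf_concrete.structured_input_signature)
               if isinstance(s, tf.TensorSpec)]
print([tuple(s.shape) for s in input_specs])
```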

Hi @Zlorf, after loading, you can get the inputs by using

tf_model.signatures["serving_default"].structured_input_signature
# output

((),
 {'first': TensorSpec(shape=(1, 2), dtype=tf.float32, name='first'),
  'second': TensorSpec(shape=(1, 3), dtype=tf.float32, name='second'),
  'third': TensorSpec(shape=(1, 10), dtype=tf.float32, name='third')})

To get the outputs, you can use

tf_model.signatures["serving_default"].structured_outputs
# output
{'output_0': TensorSpec(shape=(1, 5), dtype=tf.float32, name='output_0')}

This is because tf.saved_model.save saves a generic model, and tf.saved_model.load loads it back as a generic _UserObject. The _UserObject only has two sets of methods available:

  1. The tf.function methods that were attached to the model before you saved it
  2. The .signatures that were attached for serving.
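The `(1, 2)` shapes in the output above suggest the model was saved with a serving signature that already has a static batch size. A minimal sketch of one way to do that, assuming a batch size of 1 and a hypothetical wrapper named `serve`, is to pass a tf.function with a fixed input_signature through the `signatures` argument of tf.saved_model.save. After loading, the signature is called with keyword arguments named after the TensorSpecs.

```python
import tempfile

import tensorflow as tf

# Toy stand-in for the model from the question (three inputs, one output).
x1 = tf.keras.Input(shape=(2,), name="first")
x2 = tf.keras.Input(shape=(3,), name="second")
x3 = tf.keras.Input(shape=(10,), name="third")
hidden = tf.keras.layers.Dense(10, activation="elu")(tf.concat([x1, x2], axis=1))
y = tf.keras.layers.Dense(5, activation="softmax")(hidden)
model = tf.keras.Model(inputs=(x1, x2, x3), outputs=y)

batch_size = 1  # static batch size baked into the serving signature


@tf.function(input_signature=[
    tf.TensorSpec(shape=(batch_size, 2), dtype=tf.float32, name="first"),
    tf.TensorSpec(shape=(batch_size, 3), dtype=tf.float32, name="second"),
    tf.TensorSpec(shape=(batch_size, 10), dtype=tf.float32, name="third"),
])
def serve(first, second, third):
    # Re-pack the flat signature arguments into the tuple the model was built with.
    return model((first, second, third), training=False)


path_tf = tempfile.mkdtemp()
tf.saved_model.save(model, path_tf, signatures=serve)

tf_model = tf.saved_model.load(path_tf)
infer = tf_model.signatures["serving_default"]
# Loaded signatures accept keyword arguments only.
outputs = infer(first=tf.ones((1, 2)), second=tf.ones((1, 3)), third=tf.ones((1, 10)))
print(infer.structured_input_signature)
print({k: v.shape for k, v in outputs.items()})
```

Because the signature is traced once with fixed shapes, this route avoids re-deriving a concrete function from the restored `__call__` after loading.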

Thank you.

Hi, thank you for trying to help me with my problem. I really appreciate the time and effort you took!

My problem still remains, and I think there is a small misunderstanding. I have no problem actually getting the structured inputs and outputs of a model; I had already figured that out, although I’m a little confused that the structured inputs of models loaded with tf.saved_model.load are unordered, while for Keras models they are ordered.
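On the ordering point: the serving signature’s inputs come back as a dictionary keyed by input name (see the structured_input_signature output earlier in the thread), and loaded signatures are called with keyword arguments, so position carries no meaning there. When TensorFlow needs a flat order for such a dictionary, tf.nest flattens it by sorted key, as this small sketch with made-up specs shows:

```python
import tensorflow as tf

# Insertion order deliberately differs from alphabetical order.
specs = {
    "third": tf.TensorSpec((1, 10), tf.float32, name="third"),
    "first": tf.TensorSpec((1, 2), tf.float32, name="first"),
    "second": tf.TensorSpec((1, 3), tf.float32, name="second"),
}

# tf.nest flattens dictionaries in sorted-key order, not insertion order.
flat_names = [spec.name for spec in tf.nest.flatten(specs)]
print(flat_names)
```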

My problem is, and was, that I can’t run new_signature = tf.function(model.__call__).get_concrete_function(inputs=static_tensorspec, training=False, mask=None) for SavedModels at all. Even when I extract the structured signature from the Keras model and use it as the argument for inputs, I get the same mismatch error.

I hope it is a little clearer now where the problem in my code lies.

Sorry, I misunderstood your solution!

Your code totally solved my problem. Thank you very much!