Merging multiple models into one

I’m trying to implement a call() method that receives a tensor as its argument and routes it to one of three different models based on its shape or rank. The following implementation uses if-else conditions and works well during training:

def call(self, tensor: tf.Tensor) -> tf.Tensor:
    if len(tensor.shape) == 2: # Shape [B, 2]
        return self.model_1(tensor)
    if tensor.shape[-1] == 33: # Shape [B, H, W, 33]
        return self.model_2(tensor)
    return self.model_3(tensor) # Shape [B, H, W, 2]

However, when I save the trained model using the model.save() method and load it with tf.keras.models.load_model(), the if-else conditions no longer work. I tried converting the call() method to use tf.cond() and tf.case(), but that didn’t work at all.

Could you point me to documentation that shows how to implement this feature, please?

I have tried the following code, but it doesn’t work even during training:

def call(self, tensor: tf.Tensor) -> tf.Tensor:
    return tf.case([
        (tf.equal(tf.size(tf.shape(tensor)), 2), self.model_1(tensor)),
        (tf.equal(tf.shape(tensor)[-1], 33), self.model_2(tensor))
    ], default=self.model_3(tensor), exclusive=True)

@fabricionarcizo Welcome to the TensorFlow forum.
Have you checked that the tensor passed to your call(self, tensor: tf.Tensor) method matches one of the conditions?
Maybe you can share a reproducible example, e.g. in Colab?
Thank you.

I have solved my problem. It is not necessary to implement the call() method using tf.cond() or tf.case(). It can be implemented with if-else conditions, as in this example:

def call(self, tensor: tf.Tensor) -> tf.Tensor:
    if len(tensor.shape) == 2: # Shape [B, 2]
        return self.model_1(tensor)
    if tensor.shape[-1] == 33: # Shape [B, H, W, 33]
        return self.model_2(tensor)
    return self.model_3(tensor) # Shape [B, H, W, 2]
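A likely reason the plain if-else version is enough: tensor.shape is static Python metadata while TensorFlow traces the function, so each distinct input shape gets its own trace containing only the matching branch. Here is a minimal sketch of that behaviour. ShapeRouter and its toy reductions are hypothetical stand-ins, not the original sub-models.

```python
import tensorflow as tf


class ShapeRouter(tf.Module):
    """Hypothetical stand-in for the custom model; branches are toy ops."""

    @tf.function
    def __call__(self, tensor):
        # tensor.shape is known at trace time, so these conditions are
        # ordinary Python booleans and only one branch ends up in each graph.
        if len(tensor.shape) == 2:  # Shape [B, 2]
            return tf.reduce_sum(tensor, axis=-1)
        if tensor.shape[-1] == 33:  # Shape [B, H, W, 33]
            return tf.reduce_mean(tensor, axis=[1, 2, 3])
        return tf.reduce_max(tensor, axis=[1, 2, 3])  # Shape [B, H, W, 2]


router = ShapeRouter()
out_a = router(tf.ones((4, 2)))          # takes the first branch
out_b = router(tf.ones((4, 8, 8, 33)))   # takes the second branch
out_c = router(tf.ones((4, 8, 8, 2)))    # falls through to the last branch
```

Each call above triggers a separate trace because the input shapes differ, which is also why tf.cond() or tf.case() is not needed for this kind of routing.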

I used the pretty_printed_concrete_signatures() method to list all of the available traces in my model. I concluded that when I reloaded the model using tf.keras.models.load_model(), I couldn’t access all the sub-models of my project (i.e., self.model_1, self.model_2 and self.model_3). I must use the custom_objects argument to define the custom classes used during deserialization (see the documentation for tf.keras.models.load_model). Therefore, I must load my model using:

model = tf.keras.models.load_model(
    model_path, custom_objects={"CustomModel": CustomModel}
)

Now, when I print the available traces, I get the following result:

training_step(batch)
  Args:
    batch: (<1>, <2>, <3>, <4>, <5>)
      <1>: float32 Tensor, shape=(256, 48, 64, 3)
      <2>: float32 Tensor, shape=(256, 48, 64, 3)
      <3>: float32 Tensor, shape=(256, 48, 64, 3)
      <4>: float32 Tensor, shape=(256, 2)
      <5>: float32 Tensor, shape=(256, 48, 64, 14)
  Returns:
    NoneTensorSpec()

My problem is solved!

Now I’m facing another problem. I would like to save these three models into a single model.tflite file. I couldn’t figure out how to save the model when the tensor argument can have multiple shapes. I have tried to create a converter with the following code:

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [ model.call.get_concrete_function(tf.zeros((B, 2), tf.float32)),
      model.call.get_concrete_function(tf.zeros((B, H, W, 33), tf.float32)),
      model.call.get_concrete_function(tf.zeros((B, H, W, 2), tf.float32)) ], model)
tflite_model = converter.convert()

However, only the first shape is recognized in the converter object. I also tried decorating the call() method with @tf.function, but input_signature accepts only one signature. Where could I find documentation to help with this issue, please?
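For reference, a tf.function declared without input_signature stays polymorphic, and get_concrete_function() can be asked for one graph per tf.TensorSpec. The sketch below assumes toy branches in place of the three sub-models; the route function is hypothetical, and H = 48, W = 64 are taken from the trace printed earlier in this thread.

```python
import tensorflow as tf


@tf.function  # no input_signature, so the function can be traced per shape
def route(tensor):
    # Toy branches standing in for the three sub-models (hypothetical).
    if len(tensor.shape) == 2:
        return tensor * 2.0
    if tensor.shape[-1] == 33:
        return tf.reduce_mean(tensor, axis=-1)
    return tf.reduce_sum(tensor, axis=-1)


# One spec per expected input; None leaves the batch dimension open.
specs = [
    tf.TensorSpec([None, 2], tf.float32),
    tf.TensorSpec([None, 48, 64, 33], tf.float32),
    tf.TensorSpec([None, 48, 64, 2], tf.float32),
]
concrete = [route.get_concrete_function(spec) for spec in specs]

# Each concrete function only accepts inputs compatible with its own spec.
y1 = concrete[0](tf.ones((3, 2)))
y2 = concrete[1](tf.ones((1, 48, 64, 33)))
y3 = concrete[2](tf.ones((1, 48, 64, 2)))
```

This only shows that the three graphs exist side by side in Python. It does not by itself make the TFLite converter pick up all of them.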

I wanted to apply polymorphism to the call() method, but I couldn’t find a proper way. My solution was to create a separate method for each model:

# Wrapper methods are named method_N so they do not shadow the
# self.model_N sub-models they call.
@tf.function
def method_1(self, tensor: tf.Tensor) -> dict:
    output = self.model_1(tensor)
    return { "output": output }

@tf.function
def method_2(self, tensor: tf.Tensor) -> dict:
    output = self.model_2(tensor)
    return { "output": output }

@tf.function
def method_3(self, tensor: tf.Tensor) -> dict:
    output = self.model_3(tensor)
    return { "output": output }

Then, create the converter using:

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [ model.method_1.get_concrete_function(tf.zeros((B, 2), tf.float32)),
      model.method_2.get_concrete_function(tf.zeros((B, H, W, 33), tf.float32)),
      model.method_3.get_concrete_function(tf.zeros((B, H, W, 2), tf.float32)) ], model)
tflite_model = converter.convert()

I have to call the methods by name (e.g., model.method_1(tensor)) instead of calling the model object directly (e.g., model(tensor)). But at least I could export the three models in a single *.tflite file.
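A possible variant of the same idea, sketched under assumptions: export the three methods as named SavedModel signatures, convert with from_saved_model(), and pick a signature at inference time with get_signature_runner(). MultiHead and its toy ops are hypothetical stand-ins for the real model, the batch size is fixed to 1 for simplicity, and the sketch assumes a TensorFlow release where multi-signature conversion and tf.lite.Interpreter are both available.

```python
import tempfile

import numpy as np
import tensorflow as tf


class MultiHead(tf.Module):
    """Hypothetical stand-in: toy ops replace the three sub-models."""

    @tf.function(input_signature=[tf.TensorSpec([1, 2], tf.float32)])
    def method_1(self, tensor):
        return {"output": tensor * 2.0}

    @tf.function(input_signature=[tf.TensorSpec([1, 48, 64, 33], tf.float32)])
    def method_2(self, tensor):
        return {"output": tf.reduce_mean(tensor, axis=-1)}

    @tf.function(input_signature=[tf.TensorSpec([1, 48, 64, 2], tf.float32)])
    def method_3(self, tensor):
        return {"output": tf.reduce_sum(tensor, axis=-1)}


model = MultiHead()
export_dir = tempfile.mkdtemp()

# Register each method under its own signature name in the SavedModel.
signatures = {
    "method_1": model.method_1,
    "method_2": model.method_2,
    "method_3": model.method_3,
}
tf.saved_model.save(model, export_dir, signatures=signatures)

# Convert all three signatures into a single TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(
    export_dir, signature_keys=list(signatures)
)
tflite_model = converter.convert()

# At inference time, select the signature by name; the keyword argument
# matches the Python parameter name ("tensor").
interpreter = tf.lite.Interpreter(model_content=tflite_model)
runner = interpreter.get_signature_runner("method_1")
result = runner(tensor=np.ones((1, 2), dtype=np.float32))
```

The caller still chooses the signature by name, so this does not restore model(tensor)-style dispatch. It only keeps the three entry points addressable inside one file.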