I’m trying to implement a `call()` method that receives a tensor as an argument and sends it to one of three different models based on its shape or its number of dimensions. The following implementation uses if-else conditions, and it works well during training:
```python
def call(self, tensor: tf.Tensor) -> tf.Tensor:
    if len(tensor.shape) == 2:      # Shape [B, 2]
        return self.model_1(tensor)
    if tensor.shape[-1] == 33:      # Shape [B, H, W, 33]
        return self.model_2(tensor)
    return self.model_3(tensor)     # Shape [B, H, W, 2]
```
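A minimal sketch, assuming TensorFlow 2.x, of how a Python `if`/`else` on `tensor.shape` behaves inside a `tf.function` (Keras traces `call()` into graphs during `fit()` and when saving). These checks are ordinary Python and run only while a graph is being traced, so each new input shape gets its own traced graph. The function `route` and the list `traced_ranks` are hypothetical names, and simple reductions stand in for `model_1`, `model_2` and `model_3`.

```python
import tensorflow as tf

traced_ranks = []  # Python side effects like this run only during tracing


@tf.function
def route(tensor):
    # Hypothetical stand-in for call(): these checks read the *static* shape,
    # so they are evaluated once per trace, not once per call.
    traced_ranks.append(len(tensor.shape))
    if len(tensor.shape) == 2:                          # [B, 2]
        return tf.reduce_sum(tensor, axis=-1)           # stand-in for model_1
    if tensor.shape[-1] == 33:                          # [B, H, W, 33]
        return tf.reduce_mean(tensor, axis=[1, 2, 3])   # stand-in for model_2
    return tf.reduce_max(tensor, axis=[1, 2, 3])        # stand-in for model_3


route(tf.ones([4, 2]))
route(tf.ones([4, 2]))          # same shape and dtype: reuses the first graph
route(tf.ones([4, 5, 5, 33]))   # new shape: traced again
route(tf.ones([4, 5, 5, 2]))    # new shape: traced again
print(traced_ranks)             # [2, 4, 4]
```

One plausible consequence, not confirmed by the post: a SavedModel stores the traced graphs rather than the Python `if` itself, so a reloaded model may only accept the input shapes it was traced with.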
However, when I save the trained model with `model.save()` and load it with `tf.keras.models.load_model()`, the if-else condition no longer works. I tried rewriting the `call()` method with `tf.cond()` and `tf.case()`, but neither worked.
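The sketch below, again assuming TensorFlow 2.x with hypothetical names (`route_with_cond`, `traced_branches`) and reductions as stand-ins for the sub-models, illustrates one likely obstacle with `tf.cond()`: when the predicate is a tensor, both branch functions are traced with the same input tensor, and both must return compatible outputs. A branch such as `model_1`, which expects a rank-2 input, would then also be built for rank-4 inputs, which may fail. `tf.case()` is built on the same conditional mechanism, so it likely behaves the same way.

```python
import tensorflow as tf

traced_branches = []  # records which branch functions were traced


@tf.function
def route_with_cond(tensor):
    def branch_rank2():
        traced_branches.append("rank2")
        return tf.reduce_sum(tensor)   # hypothetical stand-in for model_1

    def branch_other():
        traced_branches.append("other")
        return tf.reduce_max(tensor)   # hypothetical stand-in for model_3

    # The predicate is a tensor, so tf.cond adds BOTH branches to the graph,
    # each receiving the same `tensor` with the same static shape.
    return tf.cond(tf.equal(tf.rank(tensor), 2), branch_rank2, branch_other)


result = route_with_cond(tf.ones([4, 2]))
print(sorted(traced_branches))  # ['other', 'rank2']
print(float(result))            # 8.0 (sum of eight ones)
```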
Could you point me to documentation that shows how to implement this feature, please?