I'm trying to implement a `call()` method that receives a tensor as its argument and dispatches it to one of three different models based on the tensor's shape or number of dimensions. The following implementation uses `if-else` conditions, and it works well during training:
```python
import tensorflow as tf

def call(self, tensor: tf.Tensor) -> tf.Tensor:
    if len(tensor.shape) == 2:       # Shape [B, 2]
        return self.model_1(tensor)
    if tensor.shape[-1] == 33:       # Shape [B, H, W, 33]
        return self.model_2(tensor)
    return self.model_3(tensor)      # Shape [B, H, W, 2]
```
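For context, here is a minimal sketch (assuming TensorFlow 2.x) of how this kind of Python `if-else` behaves once the code runs inside a `tf.function`, which is what happens when a model is saved and reloaded. The function `route`, the list `trace_log`, and the `tensor * k` expressions are hypothetical stand-ins for `model_1`, `model_2`, and `model_3`. Because `tensor.shape` is the static shape seen at trace time, the branch is picked once per new input signature and baked into that trace, rather than being re-evaluated on every call.

```python
import tensorflow as tf

trace_log = []  # Python list; appended to only while TensorFlow traces `route`


@tf.function
def route(tensor):
    # Inside tf.function, `tensor.shape` is the static shape seen at trace
    # time, so this Python if-else runs once per new input signature and the
    # chosen branch is baked into that trace.
    if len(tensor.shape) == 2:          # [B, 2]
        trace_log.append("model_1")
        return tensor * 1.0             # hypothetical stand-in for model_1
    if tensor.shape[-1] == 33:          # [B, H, W, 33]
        trace_log.append("model_2")
        return tensor * 2.0             # hypothetical stand-in for model_2
    trace_log.append("model_3")         # [B, H, W, 2]
    return tensor * 3.0                 # hypothetical stand-in for model_3


out_1 = route(tf.ones([4, 2]))
out_2 = route(tf.ones([1, 5, 5, 33]))
out_3 = route(tf.ones([1, 5, 5, 2]))
route(tf.ones([4, 2]))  # same shape and dtype: reuses the first trace, no append
print(trace_log)  # ['model_1', 'model_2', 'model_3']
```

The side effect on `trace_log` only fires during tracing, which makes it easy to see that each input layout gets its own trace.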
However, when I save the trained model with `model.save()` and load it with `tf.keras.models.load_model()`, the `if-else` condition no longer works. I tried converting the `call()` method to use `tf.case()`, but that didn't work either.
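One direction that may be worth checking, sketched here under stated assumptions: it uses a plain `tf.Module` with `tf.saved_model.save()` / `tf.saved_model.load()` instead of `model.save()` / `tf.keras.models.load_model()`, and `ShapeRouter` and the weights `w1`, `w2`, `w3` are hypothetical stand-ins for the three sub-models. The idea is to trace one concrete function per input layout before saving, since a SavedModel stores only the `tf.function` traces that exist at save time, and calls whose signature matches none of them are rejected after loading. It also keeps the static-shape `if-else` rather than `tf.case()`, because graph-mode conditionals build every branch for the same input, so the `[B, 2]` model would likely be built on a 4-D tensor and fail. Whether this carries over directly to a `tf.keras.Model` subclass would still need to be verified.

```python
import tempfile

import tensorflow as tf


class ShapeRouter(tf.Module):
    """Hypothetical stand-in for the question's model: picks one of three
    sub-networks from the static shape of its input."""

    def __init__(self):
        super().__init__()
        # Hypothetical "sub-models": one weight matrix per input layout.
        self.w1 = tf.Variable(tf.ones([2, 4]))          # for [B, 2]
        self.w2 = tf.Variable(tf.ones([33, 4]))         # for [B, H, W, 33]
        self.w3 = tf.Variable(2.0 * tf.ones([2, 4]))    # for [B, H, W, 2]

    @tf.function
    def __call__(self, tensor):
        # Static-shape dispatch, resolved separately for each traced signature.
        if tensor.shape.rank == 2:
            return tf.matmul(tensor, self.w1)
        if tensor.shape[-1] == 33:
            return tf.einsum("bhwc,cd->bhwd", tensor, self.w2)
        return tf.einsum("bhwc,cd->bhwd", tensor, self.w3)


router = ShapeRouter()

# Trace one concrete function per input layout *before* saving; the
# SavedModel keeps only the traces that exist at save time.
for spec in (
    tf.TensorSpec([None, 2], tf.float32),
    tf.TensorSpec([None, None, None, 33], tf.float32),
    tf.TensorSpec([None, None, None, 2], tf.float32),
):
    router.__call__.get_concrete_function(spec)

export_dir = tempfile.mkdtemp()
tf.saved_model.save(router, export_dir)
restored = tf.saved_model.load(export_dir)

out_1 = restored(tf.ones([3, 2]))           # routed to w1
out_2 = restored(tf.ones([1, 5, 5, 33]))    # routed to w2
out_3 = restored(tf.ones([1, 5, 5, 2]))     # routed to w3
```

Because the batch, height, and width dimensions are `None` in the specs, each saved trace should accept any size along those axes while still routing by rank and channel count.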
Could you point me to documentation that shows how to implement this feature, please?