Error Saving TF-Java Model / Exporter API

I keep getting the error `No Operation named [StatefulPartitionedCall_2:0] in the Graph` when using `SavedModelBundle.exporter` to save the model.

  • TensorFlow Python version: 2.4.1
  • TensorFlow Java version: 0.3.1
  • OS: Windows 10
  • GPU/CPU: CPU

The export code:

```
ConcreteFunction serveFunction = savedModel.function("serve_model");
SavedModelBundle.exporter(exportDir)
                .withFunction(serveFunction)
                .export();
```

Inspecting the graph operations, I can see StatefulPartitionedCall_2, but without the `:0` suffix at the end of the operation name:

```
Iterator<Operation> operationIterator = serveFunction.graph().operations();
while (operationIterator.hasNext()) {
    System.out.println(operationIterator.next().name());
}
```

Code snippet output:

```
Adam/iter
Adam/iter/Read/ReadVariableOp
Adam/beta_1
Adam/beta_1/Read/ReadVariableOp
Adam/beta_2
...
...
...
train_model_labels
StatefulPartitionedCall_1
saver_filename
StatefulPartitionedCall_2
StatefulPartitionedCall_3
```
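
In TensorFlow, `StatefulPartitionedCall_2:0` names output 0 of the operation `StatefulPartitionedCall_2`, so the `:0` is an output index rather than part of the operation name. A minimal sketch, assuming the standard TF-Java `Graph`/`Operation` API, of resolving that output from the bare operation name:

```
// Look up the operation by its bare name (no ":0" suffix)...
GraphOperation op = serveFunction.graph().operation("StatefulPartitionedCall_2");
// ...then select its first output, i.e. the tensor "StatefulPartitionedCall_2:0".
Output<?> firstOutput = op.output(0);
System.out.println(firstOutput);
```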

It works fine when invoking the op directly from `session.runner()`:

```
String checkpointPath = "...";
// feed() expects a Tensor, so wrap the checkpoint path in a TString scalar
try (TString checkpointPathTensor = TString.scalarOf(checkpointPath)) {
    session.runner()
           .feed("saver_filename:0", checkpointPathTensor)
           .fetch("StatefulPartitionedCall_2:0")
           .run();
}
```

The error can be reproduced with the following script, which defines and then saves the model (credits to Thierry Herrmann):

```
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

def make_model():

    class CustomLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            l2_reg = keras.regularizers.l2(0.1)
            self.dense = layers.Dense(1, kernel_regularizer=l2_reg,
                                      name='my_layer_dense')

        def call(self, data):
            return self.dense(data)

    inputs = keras.Input(shape=(8,))
    x1 = layers.Dense(30, activation="relu", name='my_dense')(inputs)
    outputs = CustomLayer()(x1)
    return keras.Model(inputs=inputs, outputs=outputs)


class CustomModule(tf.Module):

    def __init__(self):
        super(CustomModule, self).__init__()
        self.model = make_model()
        self.opt = keras.optimizers.Adam(learning_rate=0.001)

    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
    def __call__(self, X):
        return self.model(X)

    # the my_train function processes one batch (one step): it computes the loss and applies
    # the loss gradient to update the model weights
    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32), tf.TensorSpec([None], tf.float32)])
    def my_train(self, X, y):
        with tf.GradientTape() as tape:
            logits = self.model(X, training=True)
            main_loss = tf.reduce_mean(keras.losses.mean_squared_error(y, logits))
            # self.model.losses contains the regularization loss (see l2_reg above)
            loss_value = tf.add_n([main_loss] + self.model.losses)

        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.model.trainable_weights))
        return loss_value


# instantiate the module
module = CustomModule()


def save_module(module, model_dir):

    tf.saved_model.save(module, model_dir,
                        signatures={
                            'serve_model' :
                                module.__call__.get_concrete_function(tf.TensorSpec([None, 8], tf.float32)),
                            'train_model' :
                                module.my_train.get_concrete_function(tf.TensorSpec([None, 8], tf.float32),
                                                                      tf.TensorSpec([None], tf.float32))})


MODEL_OUTPUT_DIR ="..."
save_module(module, MODEL_OUTPUT_DIR)
```
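
For completeness, here is a minimal sketch of loading that SavedModel back in TF-Java and invoking the `train_model` signature. This is an assumption-laden illustration, not tested code: the argument names `X` and `y` are guessed from the Python `tf.function` parameters, and the model directory is a placeholder.

```
import java.util.HashMap;
import java.util.Map;

import org.tensorflow.ConcreteFunction;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

public class TrainFromJava {
    public static void main(String[] args) {
        String modelDir = "...";  // placeholder: the MODEL_OUTPUT_DIR used above

        try (SavedModelBundle savedModel = SavedModelBundle.load(modelDir, "serve")) {
            ConcreteFunction trainFunction = savedModel.function("train_model");

            // One dummy batch: 2 rows of 8 features, plus 2 labels.
            try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] {
                     {0, 1, 2, 3, 4, 5, 6, 7},
                     {7, 6, 5, 4, 3, 2, 1, 0}}));
                 TFloat32 y = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] {0f, 1f}))) {

                Map<String, Tensor> arguments = new HashMap<>();
                arguments.put("X", x);  // assumed input name
                arguments.put("y", y);  // assumed input name

                Map<String, Tensor> result = trainFunction.call(arguments);
                result.forEach((name, tensor) -> {
                    System.out.println(name + " -> " + tensor);
                    tensor.close();
                });
            }
        }
    }
}
```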

For those interested in following this topic, the discussion is happening in this GitHub issue.