Load and run a model with Convolution and Pooling layers in Java

I am trying to load a Sequential model that I trained in Python and saved with `tf.saved_model.save()`. When I load the model in Java with the `SavedModelBundle` class and use it to predict an output for a given input, it fails whenever the model contains Convolution or Max Pooling layers, with the following error:

```
org.tensorflow.exceptions.TFInvalidArgumentException: Input to reshape is a tensor with 4282281 values, but the requested shape has 2569368600
```
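The two numbers in that message are related: 2569368600 is exactly 600 × 4282281. That alone does not identify the cause, but an exact multiple like this is consistent with a reshape inside the graph reading one dimension (for example the batch dimension) differently from how the input was built in Java. A reasonable first step is therefore to compare the shape of the Java input tensor with the input shape the model was exported with. The sketch below only does that arithmetic in plain Java and does not call TF-Java; the class name `ReshapeMismatch`, the `numElements` helper and the `{1, 28, 28, 1}` shape are hypothetical examples, not the model's real input.

```java
/**
 * Hypothetical diagnostic: compares the number of values the runtime reported
 * for the input with the number of values the reshape asked for.
 */
public class ReshapeMismatch {

    /** Product of all dimensions, i.e. how many values a tensor of this shape holds. */
    static long numElements(long... shape) {
        long n = 1;
        for (long d : shape) {
            n = Math.multiplyExact(n, d); // fail loudly on overflow instead of wrapping
        }
        return n;
    }

    public static void main(String[] args) {
        long supplied = 4_282_281L;      // "a tensor with 4282281 values"
        long requested = 2_569_368_600L; // "the requested shape has 2569368600"

        // An exact integer ratio suggests one dimension is being interpreted
        // differently from what was intended (an assumption, not a diagnosis).
        System.out.println("requested / supplied = " + requested / supplied
                + " (remainder " + requested % supplied + ")");

        // The same check applied to a hypothetical single-image NHWC input.
        long[] javaInputShape = {1, 28, 28, 1};
        System.out.println("hypothetical input holds "
                + numElements(javaInputShape) + " values");
    }
}
```

Running the same `numElements` check on the shape actually passed from Java, and on the shape reported for the model's serving signature in Python, should show whether the two agree.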

When compiling and saving the model in Python, I also get the following warning, which might be relevant:

```
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
```

If I don’t use MaxPooling or Convolution layers, everything works as intended.
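One possible explanation for the difference (an assumption, since neither the model definition nor the Java input code is shown) is the input layout rather than the layers themselves. A Keras model made only of Dense layers is typically fed a 2-D `[batch, features]` tensor, whereas `Conv2D` and `MaxPooling2D` expect a 4-D tensor, by default in channels-last order: `[batch, height, width, channels]`. If the Java side builds the input with a different shape or element order, reshapes inside the graph can fail. The sketch below uses plain Java only; the class `NhwcLayout`, its methods and the 2×2×3 example image are hypothetical. It shows how a `[height][width][channels]` array maps onto a row-major NHWC buffer with a leading batch dimension of 1.

```java
import java.util.Arrays;

/** Hypothetical helper: lays out one image as a row-major NHWC buffer with batch size 1. */
public class NhwcLayout {

    /** Shape {1, height, width, channels} for a single channels-last image. */
    static long[] nhwcShape(float[][][] image) {
        return new long[] {1, image.length, image[0].length, image[0][0].length};
    }

    /** Copies image[h][w][c] so that the channel index varies fastest (row-major NHWC). */
    static float[] flattenNhwc(float[][][] image) {
        int height = image.length;
        int width = image[0].length;
        int channels = image[0][0].length;
        float[] out = new float[height * width * channels];
        int i = 0;
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    out[i++] = image[h][w][c];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up 2x2 RGB image; each value encodes its position as h*100 + w*10 + c.
        float[][][] image = new float[2][2][3];
        for (int h = 0; h < 2; h++) {
            for (int w = 0; w < 2; w++) {
                for (int c = 0; c < 3; c++) {
                    image[h][w][c] = h * 100 + w * 10 + c;
                }
            }
        }
        System.out.println("shape  = " + Arrays.toString(nhwcShape(image)));
        System.out.println("buffer = " + Arrays.toString(flattenNhwc(image)));
    }
}
```

With TF-Java, a buffer like this would then be wrapped in a float tensor of that shape before being passed to the model; the exact tensor factory methods differ between TF-Java releases, so check the API of the version in use.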

How can I execute models that have these types of layers in Java?

Can you check whether the saved model loads and executes correctly in Python or TF Serving? If the error comes from the layers inside the model (rather than from the input), then it’s not something TF-Java has control over. We use the C API to load the graph inside a saved model and execute it in the native runtime, which should be shared across Python, Java, Serving, Go and Rust.

The model does load and execute correctly in Python. I have not tried TF Serving yet, but will try that next.

Edit: The model also works with TF Serving, so I do not think the issue is with the model itself.