Speeding up TF predictions on small datasets


I’m looking for advice on how to speed up my model predictions in my production setting.

I’ve just(!) moved from TensorFlow 1.14 to TensorFlow 2.3, and changes to model.predict have made the method significantly slower in 2.3 when it is called on very small datasets (discussed in "model.predict is much slower on TF 2.1+", tensorflow/tensorflow issue #40261 on GitHub, amongst other places). Unfortunately, this is exactly my use case: predicting on small (~10-row) datasets that cannot be constructed in advance.

Going from 1.14 to 2.3, I find that model.predict calls take approximately 5x as long. The main suggestion I’ve found in these discussions is to call the model directly instead of using model.predict. This is certainly faster, but still 2x slower than model.predict in 1.14. There’s some additional overhead from converting the numpy input to a tf.Tensor, but this is negligible compared with the model call time.
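A minimal sketch for reproducing the comparison described above on your own hardware. It assumes a small hypothetical Keras model (`build_model`, 20 input features) standing in for the real one, and `mean_latency` is likewise a hypothetical helper. It times model.predict, a direct model call, and a direct call that includes the explicit numpy-to-tf.Tensor conversion, all on a 10-row batch.

```python
import time

import numpy as np
import tensorflow as tf


def build_model(n_features: int = 20) -> tf.keras.Model:
    # Hypothetical stand-in for the production model.
    inputs = tf.keras.Input(shape=(n_features,))
    hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(hidden)
    return tf.keras.Model(inputs, outputs)


def mean_latency(fn, x, repeats: int = 20) -> float:
    """Average wall-clock seconds per call, measured after one warm-up call."""
    fn(x)  # warm-up: absorbs one-off setup and tracing cost
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    return (time.perf_counter() - start) / repeats


model = build_model()
batch = np.random.rand(10, 20).astype("float32")  # ~10 rows, as in the question

timings = {
    # Full predict() path, with the progress bar switched off.
    "model.predict(x)": mean_latency(lambda x: model.predict(x, verbose=0), batch),
    # Direct call on the numpy array.
    "model(x)": mean_latency(lambda x: model(x, training=False), batch),
    # Direct call with the numpy -> tf.Tensor conversion made explicit.
    "model(tf.convert_to_tensor(x))": mean_latency(
        lambda x: model(tf.convert_to_tensor(x), training=False), batch
    ),
}

for name, seconds in timings.items():
    print(f"{name:32s} {seconds * 1e3:.3f} ms")
```

Absolute numbers depend on the model and machine, so the ratios between the three entries are the useful part; running the same script under 1.14 and 2.3 would give the version-to-version comparison.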

Are there any obvious routes to further improve performance on 2.3? I understand that there are options like freezing the model, or converting to TFLite. Should I expect any of these or similar to meaningfully improve prediction speed?
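Short of freezing or converting to TFLite, one TF2-native option to try is tracing the forward pass into a graph once with `tf.function` and a fixed `input_signature` whose batch dimension is `None`, so that batches of different sizes reuse the same graph instead of triggering a retrace. This is a sketch assuming a hypothetical 20-feature stand-in model; `serve` and `trace_log` are illustrative names, and whether this closes the gap to 1.14 for a particular model is something to measure rather than assume.

```python
import numpy as np
import tensorflow as tf

N_FEATURES = 20  # hypothetical input width

# Hypothetical stand-in for the production model.
inputs = tf.keras.Input(shape=(N_FEATURES,))
hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
model = tf.keras.Model(inputs, tf.keras.layers.Dense(1)(hidden))

# Python side effects inside a tf.function run only while it is being traced,
# so this list records how many times the graph was (re)built.
trace_log = []


@tf.function(input_signature=[tf.TensorSpec(shape=[None, N_FEATURES], dtype=tf.float32)])
def serve(x):
    trace_log.append(x.shape)  # executes once per trace, not once per call
    return model(x, training=False)


preds_10 = serve(tf.constant(np.random.rand(10, N_FEATURES).astype("float32"))).numpy()
preds_3 = serve(tf.constant(np.random.rand(3, N_FEATURES).astype("float32"))).numpy()
print(len(trace_log), preds_10.shape, preds_3.shape)
```

After the 10-row and 3-row calls, `trace_log` should hold a single entry: the variable batch size does not cause a retrace, which is the property that matters for small batches that cannot be built in advance.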

@singular1ty To address the performance issues you’re facing with TensorFlow 2.3, try the following:

  1. Explore post-training quantization to reduce numerical precision, and consider model pruning to create a smaller model.
  2. Instead of using model.predict, try a direct session run. This eliminates some overhead and may improve speed.
  3. Use the TensorFlow Profiler to identify specific bottlenecks in your model. This will guide further optimization efforts.

Experiment with different combinations and measure the impact on prediction speed.
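Point 1 (post-training quantization) is applied through the TFLite converter in TF 2.x, so it can be tried together with the TFLite route mentioned in the question. A sketch, assuming a hypothetical 20-feature stand-in model: it converts via a concrete function (`TFLiteConverter.from_keras_model` is the more direct call on TF 2.3), enables dynamic-range quantization with `tf.lite.Optimize.DEFAULT`, and runs a 10-row batch through `tf.lite.Interpreter`. Pruning is not shown. Quantized outputs only approximate the float model, and TFLite is not guaranteed to beat a direct Keras call on a server CPU, so both accuracy and latency need checking.

```python
import numpy as np
import tensorflow as tf

N_FEATURES = 20  # hypothetical input width

# Hypothetical stand-in for the production model.
inputs = tf.keras.Input(shape=(N_FEATURES,))
hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
model = tf.keras.Model(inputs, tf.keras.layers.Dense(1)(hidden))


@tf.function(input_signature=[tf.TensorSpec(shape=[None, N_FEATURES], dtype=tf.float32)])
def serve(x):
    return model(x, training=False)


converter = tf.lite.TFLiteConverter.from_concrete_functions([serve.get_concrete_function()])
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # post-training dynamic-range quantization
tflite_bytes = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
_state = {"shape": None}  # batch shape the interpreter is currently allocated for


def tflite_predict(x: np.ndarray) -> np.ndarray:
    if _state["shape"] != x.shape:
        # Resizing and reallocating is relatively costly, so only do it
        # when the batch shape actually changes.
        interpreter.resize_tensor_input(input_index, list(x.shape))
        interpreter.allocate_tensors()
        _state["shape"] = x.shape
    interpreter.set_tensor(input_index, x)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)


batch = np.random.rand(10, N_FEATURES).astype("float32")
preds = tflite_predict(batch)
print(preds.shape)
```

If the production batch size is always the same, the resize happens once and each later call is just `set_tensor`, `invoke` and `get_tensor`.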

Thanks for this:

  1. By direct session run, do you mean calling model(x) to predict? If so, whilst this is meaningfully faster than model.predict(x), it’s still substantially slower than `model.predict(x)` in TF 1.14.

  2. Of course I can do this, but I don’t think this issue depends on the specific implementation of our models: it’s common across various architectures and has been reported by many others. It seems to come from the additional machinery being set up in the (at least eager) TF2 version of model.predict.
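If the overhead really sits in model.predict's own machinery rather than in the layers, a profile of both call paths should show that directly. A sketch, assuming a hypothetical 20-feature stand-in model and a temporary log directory: it warms up both paths, then records five labelled steps each of model.predict and a direct call with `tf.profiler.experimental`. The resulting directory can be opened in TensorBoard's Profile tab (which needs the `tensorboard_plugin_profile` package) to compare where time goes.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

N_FEATURES = 20  # hypothetical input width

# Hypothetical stand-in for the production model.
inputs = tf.keras.Input(shape=(N_FEATURES,))
hidden = tf.keras.layers.Dense(64, activation="relu")(inputs)
model = tf.keras.Model(inputs, tf.keras.layers.Dense(1)(hidden))
batch = np.random.rand(10, N_FEATURES).astype("float32")

# Warm up both paths so one-off setup does not dominate the profile.
model.predict(batch, verbose=0)
model(batch, training=False)

logdir = tempfile.mkdtemp(prefix="predict_profile_")
tf.profiler.experimental.start(logdir)
for step in range(5):
    # Named trace regions make the two call paths easy to tell apart in the viewer.
    with tf.profiler.experimental.Trace("model_predict", step_num=step):
        model.predict(batch, verbose=0)
    with tf.profiler.experimental.Trace("direct_call", step_num=step):
        model(batch, training=False)
tf.profiler.experimental.stop()

profile_files = [
    os.path.join(root, name) for root, _, names in os.walk(logdir) for name in names
]
print(f"wrote {len(profile_files)} profile file(s) under {logdir}")
```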