What are `workers` and `use_multiprocessing` in `model.predict`?

Hi Team, I have built an image classification model and I need to run a batch transform. As an experiment, I looped over the input data and called the model on each image, but it took 42 seconds for 50 images. I would like to reduce the inference time, because I have about 70K images per day to run inference on.
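For scale, a back-of-the-envelope estimate from those numbers (assuming the per-image time stays roughly constant) suggests that looping one image at a time would take about 16 hours for a day's images:

```python
# Rough throughput estimate from the numbers in the question.
seconds_measured = 42      # time taken by the timed experiment
images_measured = 50       # images processed in that experiment
images_per_day = 70_000    # daily volume to process

seconds_per_image = seconds_measured / images_measured
total_seconds = seconds_per_image * images_per_day
total_hours = total_seconds / 3600

print(f"{seconds_per_image:.2f} s/image -> {total_hours:.1f} h per day")
# 0.84 s/image -> 16.3 h per day
```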

A while ago, I looked at the API documentation and noticed two additional parameters in `model.predict`: `workers` and `use_multiprocessing`.

Please find below the signature I copied from the TF documentation.

predict(
    x,
    batch_size=None,
    verbose='auto',
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False
)

Let's say I set `batch_size=1000`, `workers=10`, and `use_multiprocessing=True`. Will it run the batch transform in parallel?

As my sample size is 10,000, will it pass 1,000 samples (one `batch_size` worth) to each worker at the same time, since I set `use_multiprocessing=True`?

predict(
    x,
    batch_size=1000,
    workers=10,
    use_multiprocessing=True
)
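For the numbers in the question, `batch_size` determines how the 10,000 samples are sliced into prediction steps. The pure-Python sketch below uses a hypothetical helper, `batch_bounds` (not part of Keras), to show that this gives 10 batches of 1,000, with a smaller last batch when the sample count does not divide evenly:

```python
import math

def batch_bounds(num_samples, batch_size):
    """Return (start, end) index pairs, one per batch, as used in x[start:end]."""
    num_batches = math.ceil(num_samples / batch_size)
    return [
        (i * batch_size, min((i + 1) * batch_size, num_samples))
        for i in range(num_batches)
    ]

bounds = batch_bounds(10_000, 1_000)
print(len(bounds))            # 10
print(bounds[0], bounds[-1])  # (0, 1000) (9000, 10000)
```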

Essentially, I'm trying to speed up batch inference/prediction. Initially I was looking into multiprocessing, and then I found out that TensorFlow has built-in parameters for this. I would like to confirm with you whether they do what I expect or something else.

#tensorflow #tf2 #keras #inference #prediction

Could you please provide an update on this?

The docstring for `predict` has a bit more information:

workers: Integer. Used for generator or `keras.utils.Sequence` input
    only. Maximum number of processes to spin up when using
    process-based threading. If unspecified, `workers` will default
    to 1.
use_multiprocessing: Boolean. Used for generator or
    `keras.utils.Sequence` input only. If `True`, use process-based
    threading. If unspecified, `use_multiprocessing` will default to
    `False`. Note that because this implementation relies on
    multiprocessing, you should not pass non-picklable arguments to
    the generator as they can't be passed easily to children processes.
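The "non-picklable arguments" warning refers to Python's `pickle` module, which `multiprocessing` pools use to send work to child processes. A small hypothetical helper like `is_picklable` below can help check an argument before handing it to a generator; note that lambdas and generator objects themselves typically fail:

```python
import pickle

def is_picklable(obj):
    """Return True if obj can be serialized with pickle (needed to reach a child process)."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True

print(is_picklable(["img_0001.jpg", "img_0002.jpg"]))  # True: plain lists of strings pickle fine
print(is_picklable(lambda p: p))                        # False: lambdas cannot be looked up by name
print(is_picklable(i for i in range(3)))                # False: generator objects cannot be pickled
```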

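Without seeing your dataset, it's hard to give a definitive answer. The docstring states that these arguments control process-based parallelism, but only for generator or `keras.utils.Sequence` input; if you pass NumPy arrays or tensors, they have no effect. If your data is a generator or `Sequence`, multiple workers will prepare batches in parallel, but the model itself still runs one batch at a time, so this mainly helps when loading or preprocessing the images is the bottleneck.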
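To make that distinction concrete, here is a pure-Python sketch (not Keras code; `load_batch`, `fake_model`, and `predict_with_workers` are hypothetical stand-ins) of the pattern: several workers build batches concurrently and hand them back in order, while the "model" consumes them one at a time. It uses threads, which is roughly what `workers > 1` with `use_multiprocessing=False` means; `use_multiprocessing=True` swaps threads for processes. The bounded queue controlled by `max_queue_size` is not modeled here.

```python
from concurrent.futures import ThreadPoolExecutor

def load_batch(index, batch_size=4):
    """Hypothetical stand-in for Sequence.__getitem__: builds one batch of 'images'."""
    start = index * batch_size
    return [float(i) for i in range(start, start + batch_size)]

def fake_model(batch):
    """Hypothetical stand-in for the model's forward pass: one output per input."""
    return [x * 2.0 for x in batch]

def predict_with_workers(num_batches, workers):
    outputs = []
    # Batches are *prepared* concurrently by the pool, and pool.map
    # yields the results in the same order the batch indices were submitted...
    with ThreadPoolExecutor(max_workers=workers) as pool:
        for batch in pool.map(load_batch, range(num_batches)):
            # ...but the model still runs on one batch at a time.
            outputs.extend(fake_model(batch))
    return outputs

preds = predict_with_workers(num_batches=3, workers=2)
print(len(preds), preds[:4])  # 12 [0.0, 2.0, 4.0, 6.0]
```

Adding workers only shortens the time spent producing batches; if the forward pass dominates, the total time barely changes.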

Depending on what you're trying to do, you can also run multiple copies of the model, either on the same machine or across a cluster, e.g. by sharding your dataset so that each copy processes a different subset.
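A minimal sharding sketch, assuming a hypothetical list of image paths: each model copy takes every `num_shards`-th item starting at its own index. This is the same round-robin rule that `tf.data.Dataset.shard(num_shards, index)` applies, shown here in plain Python with a hypothetical `shard` helper.

```python
def shard(items, num_shards, shard_index):
    """Return the items assigned to one shard: positions shard_index, shard_index + num_shards, ..."""
    return items[shard_index::num_shards]

# Hypothetical file list standing in for a day's images.
image_paths = [f"img_{i:05d}.jpg" for i in range(10)]

# One shard per model copy (e.g. one per process or per machine).
shards = [shard(image_paths, num_shards=3, shard_index=k) for k in range(3)]
for k, s in enumerate(shards):
    print(k, s)
```

Every path lands in exactly one shard, so the copies can run independently and their outputs can simply be concatenated afterwards.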

If you're looking for general performance advice, consider opening a separate post that asks about that specifically (to avoid the XY problem). And check out these great docs: