TensorFlow multiprocessing model prediction

Hi.

I have a simple MNIST Keras model that I use to make predictions and save the loss. I am running on a server with multiple CPUs, so I want to use multiprocessing for a speedup.
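
For context, the model and the training data come from the standard MNIST example, roughly like this (a minimal sketch; the exact architecture does not matter, x_train / y_train are the MNIST images with one-hot labels):

import numpy as np
import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit(x_train, y_train, epochs=1)
model.save('model.h5')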

I have successfully used multiprocessing with some basic functions, but when it comes to model prediction the processes never finish, whereas the non-multiprocessing approach works fine.
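
For example, the plain sequential version of the same prediction loop completes without problems (sketch, using the same data and saved model as above):

import numpy as np
import tensorflow as tf

# same MNIST data and saved model as above
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
model = tf.keras.models.load_model('model.h5')

losses = {}
for i in range(10):
    x = tf.convert_to_tensor(np.expand_dims(x_train[i], axis=0))
    y = model(x)
    y_expanded = np.expand_dims(y_train[i], axis=0)
    losses[i] = tf.keras.losses.categorical_crossentropy(y_expanded, y)

print(losses)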

I suspect that the issue is with the model: since there is only a single model, it cannot be used in different parallel processes, so I loaded the model inside each process instead, but that did not work either.

My code is this:

import multiprocessing
from multiprocessing import Process

import numpy as np
import tensorflow as tf

# make a prediction on a single training sample and store its loss
def predict(idx, return_dict):
    x = tf.convert_to_tensor(np.expand_dims(x_train[idx], axis=0))

    # load a separate copy of the model inside each worker process
    local_model = tf.keras.models.load_model('model.h5')
    y = local_model(x)
    print('this never gets printed')
    y_expanded = np.expand_dims(y_train[idx], axis=0)
    loss = tf.keras.losses.categorical_crossentropy(y_expanded, y)
    return_dict[idx] = loss

manager = multiprocessing.Manager()
return_dict = manager.dict()
jobs = []

for i in range(10):
    p = Process(target=predict, args=(i, return_dict))
    jobs.append(p)
    p.start()
    
for proc in jobs:
    proc.join()

print(return_dict.values())

The print line in the predict function is never reached, so the problem seems to be with the model. Even when I skipped loading the model inside the function and used a global one instead, the problem persisted.
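
For reference, the global-model variant I tried looks roughly like this (the rest of the multiprocessing code is unchanged); it hangs at the same point:

import numpy as np
import tensorflow as tf

# model loaded once in the parent process instead of inside predict()
model = tf.keras.models.load_model('model.h5')

def predict(idx, return_dict):
    x = tf.convert_to_tensor(np.expand_dims(x_train[idx], axis=0))
    y = model(x)  # hangs here as well, the print below is never reached
    print('this never gets printed')
    y_expanded = np.expand_dims(y_train[idx], axis=0)
    return_dict[idx] = tf.keras.losses.categorical_crossentropy(y_expanded, y)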

I followed this thread but it did not work. My questions are now:

  1. How can I solve the model issue so the processes finish?
  2. Can I use the same x_train for all the processes?