Model parallelism in Keras

Model parallelism in Keras does not seem to be possible, because layers cannot be assigned to devices. When working with large data grids and large, complex models (e.g. built with the Keras functional API), we run out of GPU memory very quickly, so some approach to model parallelism is essential, particularly for applying AI to science problems. It would be very helpful to understand what is happening here.

If you are interested, here is a paper that discusses these issues in the context of weather forecasting:

I first opened this as a bug under TensorFlow but have not had a response…

As far as I know, even if the model you wrote in the issue could work, it would be dominated by the time spent transferring data between devices, so you should expect very low scaling efficiency from such a simple implementation.

I would recommend using model parallelism only if you cannot fit even one reasonably small batch on a single device, because, as far as I know, model parallelism currently achieves lower scaling efficiency than data parallelism.
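Since the recommendation is to prefer data parallelism whenever at least a small batch fits on one device, here is a minimal sketch of that route using `tf.distribute.MirroredStrategy`, assuming TF 2.x with `tf.keras`. The toy model, layer sizes, and random data are placeholders, not the asker's model.

```python
import numpy as np
import tensorflow as tf

# MirroredStrategy creates one replica per visible GPU (or a single CPU
# replica if no GPU is found) and keeps the model variables in sync.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Model and optimizer state must be created inside the strategy scope.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(32,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

# Placeholder data standing in for the real input grids.
x = np.random.rand(256, 32).astype("float32")
y = np.random.rand(256, 1).astype("float32")

# The global batch is split evenly across replicas, so scale it with their number.
global_batch_size = 16 * strategy.num_replicas_in_sync
history = model.fit(x, y, batch_size=global_batch_size, epochs=1, verbose=0)
```

Each replica holds a full copy of the model, so this only helps when the model plus a small per-replica batch fits on one GPU. That is exactly the condition under which data parallelism is the better choice.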

For model parallelism, there is a project called mesh-tensorflow that you may want to look at. Alternatively, you could write a custom distribution strategy (similar to TPUStrategy), but that is too tedious for testing just one model and I wouldn't recommend it.


Besides that, I can think of two reasons why your model might not work (untested, just thoughts):

  1. You split the model across different devices, but the inputs are not explicitly transferred to those devices, so the computation graph might not be able to find some variables.

  2. You should think of model parallelism as a sequential (serialized) model: y1 = model_part0(input), y2 = model_part1(y1). So you cannot write something like:

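```python
# Problematic pattern: the same name x is reassigned under several
# device scopes, then consumed on yet another device.
with tf.device("/device:GPU:0"):
    x = ...
with tf.device("/device:GPU:1"):
    x = ...

with tf.device("/device:GPU:2"):
    x = f(x)  # <- This will get confused
```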

This may leave the graph with duplicated variables, which is why the assertion `assert x.device.endswith("/GPU:1")` failed.
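A minimal sketch of the sequential view from point 2, assuming TF 2.x eager execution: each part of the model is called under its own `tf.device` scope, and the inputs and the intermediate activation `y1` are copied explicitly to the next device with `tf.identity`, which also addresses point 1. The split into `part0`/`part1`, the layer sizes, and the device selection are hypothetical. Variables are expected to be created on the device active when each part is first called (worth confirming with `v.device` on your setup). If fewer than two GPUs are visible, it runs both parts on the CPU so it still executes, just without any parallelism. It only places the computation; it does not pipeline micro-batches, so the devices still work one after the other.

```python
import numpy as np
import tensorflow as tf

# Pick two devices: the first two GPUs if available, otherwise the CPU for both
# (so the sketch still runs, just without any actual parallelism).
gpus = [d.name for d in tf.config.list_logical_devices("GPU")]
dev0, dev1 = (gpus[0], gpus[1]) if len(gpus) >= 2 else ("/CPU:0", "/CPU:0")

# Hypothetical split of one model into two sequential parts.
part0 = tf.keras.Sequential([tf.keras.layers.Dense(64, activation="relu")])
part1 = tf.keras.Sequential([tf.keras.layers.Dense(1)])


def forward(inputs):
    """y1 = part0(inputs) on dev0, then y2 = part1(y1) on dev1."""
    with tf.device(dev0):
        x = tf.identity(inputs)  # explicitly copy the inputs to dev0
        y1 = part0(x)            # first call builds part0's variables here
    with tf.device(dev1):
        y1 = tf.identity(y1)     # explicitly copy the activation to dev1
        y2 = part1(y1)           # first call builds part1's variables here
    return y2


inputs = tf.constant(np.random.rand(8, 32).astype("float32"))
targets = tf.zeros((8, 1))

# First call builds both parts under their device scopes.
outputs = forward(inputs)

# One training step; gradients flow back through both devices.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
variables = part0.trainable_variables + part1.trainable_variables
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(forward(inputs) - targets))
grads = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(grads, variables))
```

Each intermediate tensor has exactly one producer and one consumer, so there is no ambiguity about which device a tensor or variable belongs to.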
