How to process sparse “subnetworks” in parallel

Hi everyone, I have been trying to figure out this problem for a while and want to see if anyone is able to help me out.

Normally we are concerned with how many samples we can process in a single batch. Without going into unnecessary detail, my data and model are fairly small, and I have spare GPU capacity that I want to use to evaluate multiple sets of weights in parallel (the same way samples are processed in parallel). By processing multiple sets of weights, I mean that every set is applied to the same data with the same base model architecture.

I will try to use a minimal example to demonstrate…

Let’s say we have 2 samples, with 2 features

samples = [[3, 7], [9, 4]]

Then we also have 2 sets of weights. Each row is the [weight, bias] pair for one feature:

w1 = [[2, 7], [9, 5]]
w2 = [[8, 6], [6, 3]]

Our model is sparse and has two output nodes, each connected to only one input feature, so we basically just apply a simple linear operation (weight * feature + bias) to each input feature.

set model weights to w1

output = [[3 * 2 + 7, 7 * 9 + 5], [9 * 2 + 7, 4 * 9 + 5]] → [[13, 68], [25, 41]]

set model weights to w2

output = [[3 * 8 + 6, 7 * 6 + 3], [9 * 8 + 6, 4 * 6 + 3]] → [[30, 45], [78, 27]]

This requires two separate calls to model.predict(), which run sequentially.
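
For reference, the sequential version can be sketched in plain NumPy. This is only an illustration of the arithmetic above, not the actual model: `apply_sparse` is a hypothetical helper standing in for a model.predict() call, and it assumes each row of a weight set is a [weight, bias] pair for one feature.

```python
import numpy as np

samples = np.array([[3, 7], [9, 4]])   # shape (batch, features)
w1 = np.array([[2, 7], [9, 5]])        # row i = [weight, bias] for feature i
w2 = np.array([[8, 6], [6, 3]])

def apply_sparse(x, w):
    """Hypothetical stand-in for one predict() call with weights w.

    Each output node sees only its own input feature:
    out[:, i] = x[:, i] * w[i, 0] + w[i, 1]
    """
    return x * w[:, 0] + w[:, 1]

# Two separate, sequential evaluations, one per weight set.
out1 = apply_sparse(samples, w1)
out2 = apply_sparse(samples, w2)

print(out1.tolist())  # [[13, 68], [25, 41]]
print(out2.tolist())  # [[30, 45], [78, 27]]
```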

Now what I want to do is create a new model that uses both sets of weights at the same time to produce output like

[[[13, 68], [25, 41]], [[30, 45], [78, 27]]]

The dimension/structure of the output is just an example (I am not sure whether it should be different). The main goal is to compute the output for both sets of weights simultaneously, in the same predict() call. That means the new model architecture should have 8 weights in total, following the example above.
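
One way to picture the target computation is to stack the weight sets along a new leading axis and let broadcasting do the rest. The sketch below uses plain NumPy rather than a real model, and assumes each row of a weight set is a [weight, bias] pair for one feature; the name `stacked` and the (sets, batch, features) output layout are choices made for this illustration.

```python
import numpy as np

samples = np.array([[3, 7], [9, 4]])      # shape (batch, features) = (2, 2)

# Both weight sets stacked: shape (sets, features, 2) = (2, 2, 2), 8 numbers total.
stacked = np.array([
    [[2, 7], [9, 5]],   # w1: [weight, bias] per feature
    [[8, 6], [6, 3]],   # w2
])

weights = stacked[:, None, :, 0]   # shape (sets, 1, features)
biases = stacked[:, None, :, 1]    # shape (sets, 1, features)

# samples[None] has shape (1, batch, features); broadcasting against
# (sets, 1, features) yields (sets, batch, features) in one expression.
output = samples[None, :, :] * weights + biases

print(output.tolist())
# [[[13, 68], [25, 41]], [[30, 45], [78, 27]]]
```

The data is never copied per weight set here; only the small weight tensor grows with the number of sets.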

I believe this should be possible in theory, since convolutional networks do this sort of thing to some degree. The difference is that there the nodes join back up, whereas here only the input nodes are shared, and the “subnetworks” down to their respective output nodes are kept separate.

Ideally this should be extensible to N sets of weights.
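
A rough NumPy sketch of the N-set case, again assuming one [weight, bias] pair per feature in each set. `apply_weight_sets` is a hypothetical name, and the random inputs are only there to compare the batched result against a sequential loop.

```python
import numpy as np

def apply_weight_sets(x, weight_sets):
    """Evaluate N weight sets on the same data in one broadcasted expression.

    x:           array of shape (batch, features)
    weight_sets: array of shape (N, features, 2), last axis = [weight, bias]
    returns:     array of shape (N, batch, features)
    """
    x = np.asarray(x)
    w = np.asarray(weight_sets)
    scale = w[:, None, :, 0]   # (N, 1, features)
    shift = w[:, None, :, 1]   # (N, 1, features)
    return x[None, :, :] * scale + shift

# Illustrative sizes only: 5 weight sets, 4 samples, 3 features.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
weight_sets = rng.normal(size=(5, 3, 2))

batched = apply_weight_sets(x, weight_sets)

# Sequential reference: one pass per weight set, as with separate predict() calls.
looped = np.stack([x * w[:, 0] + w[:, 1] for w in weight_sets])

print(batched.shape)                 # (5, 4, 3)
print(np.allclose(batched, looped))  # True
```

The same stacking idea should carry over to a framework tensor with a leading "weight set" axis, though how to express that as a single model depends on the library in use.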