Train a multi-head model whose heads are updated separately based on a feature value


I’m interested in building a model with the following structure:
input_data → layer1
layer1 → layer21 → pred1
layer1 → layer22 → pred2

layer1 is shared: it feeds two separate layers (layer21 and layer22), and each of them produces its own output (pred1 and pred2).
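A minimal sketch of this architecture, assuming PyTorch (the post doesn't name a framework). The class name `TwoHeadModel` and the layer sizes are hypothetical placeholders.

```python
import torch
import torch.nn as nn


class TwoHeadModel(nn.Module):
    """Shared trunk (layer1) feeding two separate heads (layer21, layer22).

    Hypothetical sketch: layer sizes are placeholders, adjust them to your data.
    """

    def __init__(self, in_features: int = 8, hidden: int = 16, out_features: int = 1):
        super().__init__()
        # Shared layer: both heads read its output
        self.layer1 = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # Head 1: layer21 -> pred1
        self.layer21 = nn.Linear(hidden, out_features)
        # Head 2: layer22 -> pred2
        self.layer22 = nn.Linear(hidden, out_features)

    def forward(self, x: torch.Tensor):
        h = self.layer1(x)
        # Return both predictions; the loss decides which one is trained per sample
        return self.layer21(h), self.layer22(h)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TwoHeadModel()
    pred1, pred2 = model(torch.randn(4, 8))
    print(pred1.shape, pred2.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```

Returning both predictions from `forward` keeps the model itself a plain `nn.Module`; the per-feature routing can then live entirely in the loss.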

One tricky part: I want training to depend on a feature value. Assume there is a binary feature A (value 0 or 1). I want the model to:

apply gradient updates to layer21 → pred1 only when A = 0
apply gradient updates to layer22 → pred2 only when A = 1
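One common way to get this behavior without a special layer is to compute a per-sample loss for each head and select, per sample, which head's loss counts. The unselected head then receives zero gradient from that sample. This is a minimal sketch assuming PyTorch, a regression target trained with MSE, and feature A passed as a 0/1 tensor; `gated_two_head_loss` is a hypothetical helper, not a library function.

```python
import torch
import torch.nn.functional as F


def gated_two_head_loss(pred1, pred2, target, feature_a):
    """Route each sample's loss to exactly one head (hypothetical helper).

    feature_a == 0 -> only pred1's loss counts (gradients reach layer21)
    feature_a == 1 -> only pred2's loss counts (gradients reach layer22)
    pred1, pred2: shape (batch, 1); target, feature_a: shape (batch,)
    """
    # Per-sample losses, no reduction yet: shape (batch,)
    loss1 = F.mse_loss(pred1.squeeze(-1), target, reduction="none")
    loss2 = F.mse_loss(pred2.squeeze(-1), target, reduction="none")
    # torch.where passes zero gradient to the branch it did not select
    per_sample = torch.where(feature_a == 0, loss1, loss2)
    return per_sample.mean()
```

With a model whose `forward` returns `(pred1, pred2)`, this drops into an ordinary training loop (`loss = gated_two_head_loss(pred1, pred2, y, a)`, then `loss.backward()` and `optimizer.step()`), so no custom training function is strictly required. Two caveats: the shared layer1 still receives gradients from both heads, and optimizers with momentum or weight decay (e.g. Adam, or SGD with `momentum`/`weight_decay`) can still move a head's parameters on steps where its gradient is zero.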

From what I've found online, it seems I would have to implement a customized training function. Are there any pre-existing layer definitions I can use to achieve this instead?