I want to use the gradients of one layer to calculate the gradients of the layer that comes before it.
My motivation for doing this is, when I tried to use model parallelism using tf.device, I found out that backpropagation has been running on CPU. The entire Backprop started running on a chosen tf.device only after I wrapped the call to GradientTape(when it computes the gradient) within the tf.device context manager. Since the model is split, I want the backprop of each partition to execute on the device where that partition is placed.
Ideally, I would like to find out a method with which this oversimplified pseudocode is possible.
with tf.device(device_3):
grad_3 = tf.gradients(loss, trainable_vars_of_partition_3)
with tf.device(device_2):
grad_2 = tf.gradients(grad_3, trainable_vars_of_partition_2)
with tf.device(device_1):
grad_1 = tf.gradients(grad_2, trainable_vars_of_partition_1)
grads = concat(grad_1, grad_2, grad_3)
If something like this exists then I would be overjoyed if you could point me in the right direction.
Unfortunately, I could not find something as simple as this. The next best approach that I could think of was using the gradients of one layer to find the gradients of a layer that comes before it. Using chain rule and backpropagation, I feel that this should be possible.
I created this toy example, solving which is the first step towards the final goal.
Let’s say we have a model with 3 dense layers without activations functions. X, Y as defined as follows:
x = tf.concat([tf.random.uniform([1, 10], minval=0, maxval=0.25),
tf.random.uniform([1, 10], minval=0.25, maxval=0.5),
tf.random.uniform([1, 10], minval=0.5, maxval=0.75),
tf.random.uniform([1, 10], minval=0.75, maxval=1.),
], axis = 0)
y = tf.constant(0., shape=[4, 1])
d1 = tf.keras.layers.Dense(5, name='d1')
d2 = tf.keras.layers.Dense(2, name='d2')
d3 = tf.keras.layers.Dense(1, name='d3')
I am using a tf.function in this toy example but an answer with eager mode enabled, using GradientTape will also be appreciated.
@tf.function
def tf_func(x, y, d1, d2, d3):
# Using shortforms of these function helped the code look neater and more readable to me.
g = tf.gradients
rs = tf.reduce_sum
rm = tf.reduce_mean
o1 = d1(x)
o2 = d2(o1)
o3 = d3(o2)
l = tf.reduce_mean(tf.square(o3 - y))
w3, w2, w1 = d3.trainable_variables, d2.trainable_variables, d1.trainable_variables
tf.print('actual grads' + '=' * 80)
dl_dw3 = g(l, w3)
dl_dw2 = g(l, w2)
tf.print('dl_dw2: \n', dl_dw2)
dl_dw1 = g(l, w1)
tf.print()
tf.print()
tf.print('reference grads' + '=' * 80)
dl_do1 = g(l, o1)
dl_do2 = g(l, o2)
tf.print('dl_do2: \n', dl_do2)
dl_do3 = g(l, o3)
dl_dw1 = g(l, w1)
dl_dw2 = g(l, w2)
dl_dw3 = g(l, w3)
do3_o2 = g(o3, o2)
do2_do1 = g(o2, o1)
do3_w3 = g(o3, w3)
do2_w2 = g(o2, w2)
do1_w1 = g(o1, w1)
tf.print('testing chain_rule method' + '=' * 80)
# Added a 't' before derivatives to differentiate between ref_grads and grads obtained using chain rule
tdl_do3 = g(l, o3) # same as ref_grads
tdo3_dw3 = g(o3, w3) # same as ref_grads
tdl_dw3 = [rm(tdl_do3) * tdo3_dw3[0], rm(tdl_do3) * tdo3_dw3[1]] # same as actual grads
tdo3_do2 = g(o3, o2) # same as ref_grads
tdl_do2 = tdo3_do2 * rm(tdl_do3, axis=0) # same as ref_grads
tf.print('tdl_do2: \n', tdl_do2)
tdo2_dw2 = g(o2, w2)
tf.print('tdo2_dw2: \n', tdo2_dw2)
tdl_dw2 = [tdo2_dw2[0] * rm(tdl_do2, axis=[1]), tdo2_dw2[1] * rm(tdl_do2, axis=[1])]
tf.print('tdl_dw2: \n', tdl_dw2)
return None
tf_func(x, y, d1, d2, d3)
The output was:
actual grads================================================================================
dl_dw2:
[[[-3.04819393 -1.30051827]
[5.02123785 2.14232159]
[-0.260933906 -0.111328]
[5.87596226 2.50699162]
[1.9655633 0.838611722]], [-4.69162369 -2.0016911]]
reference grads================================================================================
dl_do2:
[[[-0.43842113 -0.187053293]
[-0.889310718 -0.379426271]
[-1.41650343 -0.604354143]
[-1.94738865 -0.830857456]]]
testing chain_rule method================================================================================
tdl_do2:
[[[-0.43842113 -0.187053293]
[-0.889310718 -0.379426271]
[-1.41650343 -0.604354143]
[-1.94738865 -0.830857456]]]
tdo2_dw2:
[[[2.10966444 2.10966444]
[-3.48670244 -3.48670244]
[0.22972326 0.22972326]
[-3.95618558 -3.95618558]
[-1.3790133 -1.3790133]], [4 4]]
tdl_dw2:
[[[-2.47443795 -1.05572414]
[4.08957386 1.74482536]
[-0.26944378 -0.114958748]
[4.64023352 1.97976542]
[1.61745286 0.690089643]], [[-4.69162369 -2.0016911]]]
For some reason, gradients wrt weights in tdl_dw2 and dl_dw2 differ slightly. Every value in tdl_dw2 is slightly less than dl_dw2 even though the gradients wrt biases are the same. I cannot figure out why.
The gradient of loss wrt to w3 is as expected.
I used tf.reduce_mean to replicate what tf.gradients was doing internally as far as I understand. Please correct me if I am wrong.
From Tensorflow’s documentations:
gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys and for x in xs.
tf.gradients constructs symbolic derivatives of sum of ys w.r.t. x in xs.
Any guidance or help will be greatly appreciated, thank you.
Some Similar StackOverflow questions(there are many more):
- python - Compute gradients across two models - Stack Overflow
- python - Is it possible to acquire an intermediate gradient? (Tensorflow) - Stack Overflow
- automatic differentiation - Breaking TensorFlow gradient calculation into two (or more) parts - Stack Overflow
Here is a colab notebook with the code: