Backpropagation in TensorFlow

Hi community,

I’m starting to learn TensorFlow and I’m following the code examples from here:

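import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')  # outputs class probabilities

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)

model = MyModel()

# Assumed definitions (not shown in the original snippet); labels are taken to be integer class ids
loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# images, labels: one batch of inputs and their labels (defined elsewhere)
with tf.GradientTape() as tape:
  predictions = model(images)  # softmax outputs, so probabilities rather than logits
  loss_value = loss(labels, predictions)  # Keras losses take (y_true, y_pred)
# One gradient tensor per trainable variable, not one per input
grads = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))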

What I don’t understand is: if I pass in several inputs, how does a single loss_value give me all the gradients?

My current understanding of backpropagation is that a single input gives a single output, which gives a single loss value, which is backpropagated to give you a single weight gradient.

This blog from Nielsen helped me understand backpropagation using Python and NumPy: http://neuralnetworksanddeeplearning.com/chap2.html

Could someone please guide me to resources that explain how TensorFlow uses a single loss_value to give you the gradients for all points in a dataset?

@hargun3045,

Welcome to the TensorFlow Forum!

When you pass several inputs to TensorFlow at once, the usual practice is to compute one loss per input and then reduce those losses to a single value by averaging or summing them. The group of inputs processed together is called a batch or mini-batch: a small portion of your dataset handled in one step, instead of a single input or the entire dataset.
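As a rough illustration of that reduction, here is a minimal sketch in plain Python rather than TensorFlow, so each step stays visible. The probabilities and labels are made-up values, and per_example_losses and batch_loss are hypothetical helper names, not TensorFlow functions. Keras loss objects such as SparseCategoricalCrossentropy average over the batch by default, which is the behavior mimicked here.

```python
import math

# Made-up softmax outputs for a mini-batch of 3 examples with 3 classes each.
batch_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
]
batch_labels = [0, 1, 2]  # true class index of each example


def per_example_losses(probs, labels):
    """Cross-entropy per example: -log of the probability given to the true class."""
    return [-math.log(p[y]) for p, y in zip(probs, labels)]


def batch_loss(probs, labels):
    """Reduce the per-example losses to one scalar by averaging."""
    losses = per_example_losses(probs, labels)
    return sum(losses) / len(losses)


losses = per_example_losses(batch_probs, batch_labels)  # one value per input
loss_value = batch_loss(batch_probs, batch_labels)      # one value for the batch
print(losses)
print(loss_value)
```

So loss_value is a single number, but it is built from every example in the batch.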

Thank you!

Thanks for the response.

But then how can a single loss value help calculate all the gradients?

As I understand backpropagation, we take one input, compute the loss, and use the intermediate activations to calculate the gradients. But it’s unclear to me how this happens in TensorFlow for a whole batch. Could you guide me to some appropriate resources?

@hargun3045,

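The loss is computed for each example in the batch, and the gradients are then calculated from the average of those losses. Because differentiation is linear, the gradient of the average loss equals the average of the per-example gradients. So tape.gradient returns one gradient per trainable variable, already averaged over the batch, rather than one gradient per input.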
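To see why one averaged loss is enough, here is a minimal sketch on a toy model in plain Python. The model (w * x + b with squared error), the data, and the helper names are all assumptions made for illustration. Route 1 backpropagates each example separately and averages the results; route 2 differentiates the single mean loss. The finite difference in route 2 is only a numerical check: TensorFlow itself uses reverse-mode automatic differentiation.

```python
# Toy model: prediction = w * x + b, per-example loss = (prediction - y) ** 2.
xs = [1.0, 2.0, 3.0]  # made-up batch of inputs
ys = [2.0, 3.0, 5.0]  # made-up targets
w, b = 0.5, 0.1       # current parameter values


def mean_loss(w, b):
    """Single scalar loss: average of the per-example squared errors."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def per_example_grad(w, b, x, y):
    """Analytic gradient (dL/dw, dL/db) of one example's squared error."""
    err = w * x + b - y
    return (2 * err * x, 2 * err)


# Route 1: one gradient per example, then average them.
grads = [per_example_grad(w, b, x, y) for x, y in zip(xs, ys)]
avg_dw = sum(g[0] for g in grads) / len(grads)
avg_db = sum(g[1] for g in grads) / len(grads)

# Route 2: differentiate the single mean loss (central finite difference).
eps = 1e-6
num_dw = (mean_loss(w + eps, b) - mean_loss(w - eps, b)) / (2 * eps)
num_db = (mean_loss(w, b + eps) - mean_loss(w, b - eps)) / (2 * eps)

print(avg_dw, num_dw)
print(avg_db, num_db)
```

Both routes agree up to rounding, and in both cases there is one gradient per parameter (w and b), not one per input.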

Please refer to the explanation in the CS229 Lecture Notes.

Thank you!