I’m trying to understand the implementation of a custom Model in Keras, particularly how the gradients are calculated from the loss.
From what I understand, when implementing a custom train_step, the loss should be a single scalar value rather than one loss value per sample in the batch (shape (batch_size,)). This is the case in the following Variational AutoEncoder (VAE) example, and it is stated more generally in the Losses documentation, which suggests that in a custom training loop you should use a loss class instance, which returns a scalar.
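To make the shapes concrete, here is a minimal NumPy sketch (not Keras code) of the reduction being discussed. A per-sample squared-error loss of shape (batch_size,) is collapsed into one scalar by averaging. In recent Keras versions, loss class instances do something similar by default: their default reduction, "sum_over_batch_size", averages over the batch. The helper names below are hypothetical.

```python
import numpy as np

def per_sample_mse(y_true, y_pred):
    """Hypothetical helper: squared error per sample, shape (batch_size,)."""
    # Average over the feature axis only, keeping one value per sample.
    return np.mean((y_true - y_pred) ** 2, axis=-1)

def reduce_mean_over_batch(per_sample_losses):
    """Collapse the (batch_size,) vector into the scalar that gets differentiated."""
    return np.mean(per_sample_losses)

y_true = np.array([[1.0, 2.0], [0.0, 1.0], [3.0, 3.0]])
y_pred = np.array([[1.5, 2.0], [0.0, 0.0], [2.0, 3.0]])

losses = per_sample_mse(y_true, y_pred)   # [0.125, 0.5, 0.5]
loss = reduce_mean_over_batch(losses)     # 0.375
print(losses.shape, loss)                 # (3,) 0.375
```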
My question is: how does tape.gradient calculate the average gradient? If we average the losses over the batch, don’t we lose the per-sample gradient information? Theoretically, I understand that we need to compute the average of the per-sample gradients in the batch, not the gradient of the average loss. Here is a very nice explanation of why the two are not equivalent: "backpropagation - What exactly is averaged when doing batch gradient descent?" (Artificial Intelligence Stack Exchange).
I would appreciate any clarifications and help!
Certainly! When training a neural network with mini-batches in TensorFlow, the loss computed for a batch is typically a scalar value representing the average of individual losses for each sample in the batch. This scalar loss is used to calculate gradients with respect to the model’s parameters using TensorFlow’s
Here’s why averaging works and is commonly used:
Stability: Averaging loss over a batch stabilizes training by reducing noise from individual sample losses.
Scale Consistency: It ensures the magnitude of loss (and thus gradients) doesn’t depend on batch size, making training behavior more consistent.
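A toy illustration of the scale-consistency point, assuming a linear model with a squared-error loss (the model and the hand-derived helper `per_sample_grads` are hypothetical). Repeating the same batch twice doubles the gradient of the summed loss but leaves the gradient of the averaged loss unchanged.

```python
import numpy as np

def per_sample_grads(w, X, y):
    """Gradient of (x_i . w - y_i)**2 w.r.t. w, one row per sample."""
    residuals = X @ w - y                  # shape (B,)
    return 2.0 * residuals[:, None] * X    # shape (B, d)

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
y = rng.normal(size=4)
w = rng.normal(size=3)

# Same samples, repeated so the batch size doubles from 4 to 8.
X2, y2 = np.concatenate([X, X]), np.concatenate([y, y])

g_sum_small = per_sample_grads(w, X, y).sum(axis=0)
g_sum_big = per_sample_grads(w, X2, y2).sum(axis=0)
g_mean_small = per_sample_grads(w, X, y).mean(axis=0)
g_mean_big = per_sample_grads(w, X2, y2).mean(axis=0)

print(np.allclose(g_sum_big, 2 * g_sum_small))  # True: the summed gradient grows with batch size
print(np.allclose(g_mean_big, g_mean_small))    # True: the averaged gradient does not
```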
The gradient of this average loss is exactly equal to the average of the individual per-sample gradients, not just an approximation of it. This follows from the linearity of differentiation: the gradient of a sum (or average) of losses equals the sum (or average) of the individual gradients. Because of this, one backward pass over the scalar loss is computationally efficient, and it is standard practice.
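Written out, with B the batch size, θ the model weights, f_θ the model and ℓ the per-sample loss (notation introduced here for illustration only), the linearity step is:

```latex
L(\theta) = \frac{1}{B}\sum_{i=1}^{B} \ell\big(f_\theta(x_i),\, y_i\big)
\quad\Longrightarrow\quad
\nabla_\theta L(\theta) = \frac{1}{B}\sum_{i=1}^{B} \nabla_\theta\, \ell\big(f_\theta(x_i),\, y_i\big)
```

The inputs x_i and targets y_i are constants here; every term on the right is differentiated with respect to θ at the same current value of θ.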
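Therefore, even though the loss is a scalar average, the gradient calculation does not discard anything the update needs. It computes the average of the per-sample gradients across the batch in a single backward pass; it simply does not return each per-sample gradient separately. That batch average is, in turn, an estimate of the gradient over the full training set. This is how TensorFlow’s automatic differentiation handles gradient calculations in mini-batch training.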
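The "estimate" part can be checked with a small NumPy sketch, again assuming a hypothetical linear model with squared error. When the dataset is split into equal-sized disjoint batches, the average of the batch-mean gradients equals the full-dataset gradient, while any single batch gradient is only an estimate of it.

```python
import numpy as np

def mean_grad(w, X, y):
    """Gradient of the mean squared error (1/B) * sum_i (x_i . w - y_i)**2 w.r.t. w."""
    residuals = X @ w - y
    return 2.0 * X.T @ residuals / len(y)

rng = np.random.default_rng(1)
X = rng.normal(size=(12, 3))
y = rng.normal(size=12)
w = rng.normal(size=3)

full = mean_grad(w, X, y)                  # gradient over the whole "dataset"
batches = np.split(np.arange(12), 4)       # 4 disjoint batches of 3 samples each
batch_grads = [mean_grad(w, X[idx], y[idx]) for idx in batches]

# Averaging the equal-sized batch gradients recovers the full gradient exactly.
print(np.allclose(np.mean(batch_grads, axis=0), full))  # True
```

In real training the batches are sampled rather than enumerated exhaustively, which is why each step's gradient is an estimate rather than the full gradient.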
@Tim_Wolfe Thank you very much for your answer.
After thinking about the problem, my intuition was that GradientTape actually keeps track of the computational graph, so even when given an average loss value it can still backpropagate the error to the individual samples, which amounts to averaging the gradients across the batch. I’m not sure how GradientTape is implemented, but it wouldn’t surprise me if that were the case.
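This intuition can be checked by hand. Below is a minimal NumPy sketch of reverse-mode differentiation through the batch mean for a hypothetical linear model with squared error. It illustrates the idea only and is not TensorFlow's actual GradientTape implementation. The backward pass hands each per-sample loss an upstream gradient of 1/B, so every sample still contributes its own gradient term, and the result matches a finite-difference gradient of the mean loss.

```python
import numpy as np

def forward(w, X, y):
    residuals = X @ w - y          # one residual per sample, shape (B,)
    per_sample = residuals ** 2    # per-sample losses, shape (B,)
    return per_sample.mean(), residuals

def backward(X, residuals):
    B = len(residuals)
    d_loss_d_per_sample = np.full(B, 1.0 / B)   # d(mean)/d(loss_i) = 1/B
    d_per_sample_d_residual = 2.0 * residuals   # d(r_i**2)/d(r_i) = 2 r_i
    upstream = d_loss_d_per_sample * d_per_sample_d_residual
    # d(r_i)/dw = x_i, so each sample contributes its own row of X.
    return X.T @ upstream

def numerical_grad(w, X, y, eps=1e-6):
    """Central finite differences of the mean loss, for checking only."""
    g = np.zeros_like(w)
    for j in range(len(w)):
        step = np.zeros_like(w)
        step[j] = eps
        g[j] = (forward(w + step, X, y)[0] - forward(w - step, X, y)[0]) / (2 * eps)
    return g

rng = np.random.default_rng(2)
X = rng.normal(size=(5, 3))
y = rng.normal(size=5)
w = rng.normal(size=3)

loss, residuals = forward(w, X, y)
analytic = backward(X, residuals)
print(np.allclose(analytic, numerical_grad(w, X, y), atol=1e-6))  # True
```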
My concern with computing the gradient of the average loss is that although the gradient is a linear operator, linearity only applies when the gradients are evaluated at the same point x of the loss function (the link I posted describes this more eloquently). That is not the case for a batch, which contains different inputs. So in a way we’re differentiating batch_size separate loss functions, each shaped by its corresponding input, and each one gives a gradient direction for the network’s weights. Averaging these gradients makes sense. But if we average the loss values, what does the resulting loss represent, and which loss function does it belong to? This is still mathematically confusing to me.
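One way to see which function is being differentiated: the averaged loss L(w) = (1/B) Σ ℓ_i(w) is itself a single ordinary function of the weights w, built from all B inputs. The inputs are fixed constants and only w varies. A minimal NumPy check, assuming a hypothetical logistic-regression model with binary cross-entropy, compares the average of the per-sample gradients (all taken at the same w) with a finite-difference gradient of that mean loss. What would not be equivalent is evaluating the per-sample gradients at different weight values, or feeding the model an averaged input; neither happens here.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def per_sample_loss(w, X, y):
    """Binary cross-entropy per sample, shape (B,)."""
    p = sigmoid(X @ w)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def per_sample_grads(w, X, y):
    """Row i is the gradient of loss_i w.r.t. w, evaluated at the same w for every i."""
    return (sigmoid(X @ w) - y)[:, None] * X

def mean_loss(w, X, y):
    """The scalar batch loss: one function of w that depends on all B inputs."""
    return per_sample_loss(w, X, y).mean()

def numerical_grad(f, w, eps=1e-6):
    """Central finite differences, for checking only."""
    g = np.zeros_like(w)
    for j in range(len(w)):
        step = np.zeros_like(w)
        step[j] = eps
        g[j] = (f(w + step) - f(w - step)) / (2 * eps)
    return g

rng = np.random.default_rng(3)
X = rng.normal(size=(6, 2))
y = rng.integers(0, 2, size=6).astype(float)
w = rng.normal(size=2)

avg_of_grads = per_sample_grads(w, X, y).mean(axis=0)
grad_of_avg = numerical_grad(lambda v: mean_loss(v, X, y), w)
print(np.allclose(avg_of_grads, grad_of_avg, atol=1e-6))  # True
```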