What is the implication of a loop inside a GradientTape context?

Hi,

The general usage is to create a tape, perform a forward pass on the model, compute the gradients, and then apply them using the optimizer.

I understand that the tape records the intermediate values, which makes the computation of gradients easier.
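For reference, this is the standard single-batch training step I have in mind (a minimal sketch; model, loss_fn, optimizer and the batch X, Y are assumed to be defined elsewhere):

import tensorflow as tf

with tf.GradientTape() as tape:
  logits = model(X, training=True)   # forward pass recorded by the tape
  loss = loss_fn(Y, logits)          # scalar loss
# backward pass: gradients of the loss w.r.t. the trainable weights
gradients = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(gradients, model.trainable_weights))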

I am interested in what happens underneath if you have a for-loop inside the tape context, where every iteration of the loop runs a forward pass of the model. Something like this:

num_epochs = 10
n_iterations_per_epoch = 1000
n_batch_per_train_step = 5

for epoch in range(num_epochs):
  for i in range(n_iterations_per_epoch):
    with tf.GradientTape() as tape:
      # a plain tensor is enough for accumulation; no tf.Variable needed
      accum_loss = tf.constant(0.)
      for j in range(n_batch_per_train_step):
        X, Y = ...  # obtained from the tf.data.Dataset
        logits = model(X)
        # accumulate the loss over the mini-batches
        accum_loss = accum_loss + loss_fn(Y, logits)

    # compute the gradients of the accumulated loss
    gradients = tape.gradient(accum_loss, model.trainable_weights)
    # average the gradients over the mini-batches
    gradients = [g * (1. / n_batch_per_train_step) for g in gradients]
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Here, the total loss is accumulated across the mini-batches and the gradients are computed on this total loss.

This is essentially a way to use mini-batches of a smaller size (so they fit in GPU memory). I am aware of the alternative way of accomplishing this, i.e. instead of accumulating the loss, accumulate the gradients for each mini-batch, then sum or average the gradients and apply them with the optimizer (sketched below).
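For comparison, a minimal sketch of that gradient-accumulation alternative (one short-lived tape per mini-batch; the same model, loss_fn, optimizer, n_batch_per_train_step and data pipeline as above are assumed):

# one accumulator per trainable variable
accum_grads = [tf.zeros_like(w) for w in model.trainable_weights]
for j in range(n_batch_per_train_step):
  X, Y = ...  # obtained from the tf.data.Dataset
  with tf.GradientTape() as tape:
    logits = model(X, training=True)
    loss = loss_fn(Y, logits)
  grads = tape.gradient(loss, model.trainable_weights)
  # accumulate per-variable gradients; the tape is released after each mini-batch
  accum_grads = [a + g for a, g in zip(accum_grads, grads)]

# average over the mini-batches and apply once per training step
accum_grads = [g / n_batch_per_train_step for g in accum_grads]
optimizer.apply_gradients(zip(accum_grads, model.trainable_weights))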

Here, I am interested in understanding what happens inside TensorFlow for such code. Does it allocate dedicated memory for the intermediate values of every forward pass, or does it overwrite the memory allocated per trainable variable?

I would appreciate it if someone could provide the necessary insights.

Regards
Kapil

Has anyone found an answer to this? I'm having the same concern.