Creating multiple parallel gradient tapes


I need to compute the gradient of a scalar-valued function with vector inputs, for many different inputs. I am currently doing something like the code below, but no matter what value I set for parallel_iterations in the tf.while_loop, it only computes the gradient for one input at a time. What am I missing?

import tensorflow as tf
import time

def f(x):
    x_hat = tf.signal.rfft(x)
    x_recon = tf.signal.irfft(x_hat)
    lnp = tf.reduce_sum(x_recon)
    return lnp

def ode_fn(x):
    return -x + 1.0

def integrate(x, nsteps, time_step):
    y = tf.TensorArray(dtype=tf.float32, size=nsteps)
    x_next = x
    for i in tf.range(nsteps):
        x_next = x_next + time_step * ode_fn(x_next)
        y = y.write(i, x_next)
    return y.stack()

def compute_grads(x):
    lnpgrad = tf.TensorArray(dtype=tf.float32, size=x.shape[0])

    def cond(i, lnpgrad):
        return tf.less(i, x.shape[0])

    def body(i, lnpgrad):
        x_i = x[i]
        with tf.GradientTape() as tape:
            tape.watch(x_i)  # x_i is a plain tensor, not a variable, so it must be watched explicitly
            lnp_i = tf.reduce_sum(integrate(x_i, 20000, 0.1))
        lnpgrad = lnpgrad.write(i, tape.gradient(lnp_i, x_i))
        return i + 1, lnpgrad

    i = tf.constant(0, dtype=tf.int32)
    i, lnpgrad = tf.while_loop(cond,
                               body,
                               loop_vars=[i, lnpgrad],
                               parallel_iterations=10)
    lnpgrad = lnpgrad.stack()
    return lnpgrad

x = tf.random.normal(shape=[10, 5000])

start = time.time()
fx_grads = compute_grads(x)
end = time.time()
print(f"Elapsed {end - start} seconds")

From my understanding, tf.GradientTape() records a computation once, and after a single gradient() call it releases the resources it holds. To compute gradients repeatedly inside the while loop, consider setting the persistent argument to True in tf.GradientTape().
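For illustration, here is a minimal sketch (toy function and values, unrelated to the question's code) of what persistent=True buys you: multiple gradient() calls against the same recorded computation.

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)   # x is a plain tensor, so it must be watched explicitly
    y = x * x       # y = x^2
    z = y * y       # z = x^4

# With a non-persistent tape, the second gradient() call below would raise.
dy_dx = tape.gradient(y, x)  # 2x   -> 6.0
dz_dx = tape.gradient(z, x)  # 4x^3 -> 108.0
del tape  # free the resources held by the persistent tape
```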

A persistent gradient tape is for computing gradients over the same computation multiple times, but that is not what I am doing. For each row i in x, I perform some transformations and compute the gradient of the result with respect to x[i]. Each x[i] also uses a different gradient tape instance/context, so I don't think a persistent tape is the problem here.

Anyway, I did try with a persistent tape and it didn’t make any difference in parallelizing the computation of gradients.
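To make the point above concrete, here is a minimal sketch (toy function, assumed small shapes) showing that a fresh, non-persistent tape per iteration already computes each per-row gradient correctly, since each tape only ever serves one gradient() call:

```python
import tensorflow as tf

x = tf.random.normal(shape=[4, 8])
grads = []
for i in range(x.shape[0]):
    x_i = x[i]
    with tf.GradientTape() as tape:  # a brand-new tape instance each iteration
        tape.watch(x_i)              # x_i is a plain tensor, so watch it
        y_i = tf.reduce_sum(x_i * x_i)
    grads.append(tape.gradient(y_i, x_i))  # exactly one gradient() call per tape

grads = tf.stack(grads)  # d/dx sum(x^2) = 2x, row by row
```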