Creating multiple parallel gradient tapes

Hi,

I need to compute the gradient of a scalar-valued function with vector inputs, for many different inputs. I am currently doing something like the code below, but no matter what value I set for the parallel_iterations argument of tf.while_loop, it only computes the gradient for one input at a time. What am I missing?

import tensorflow as tf
import time

# Example scalar-valued function of a vector input (not called in the run below).
@tf.function
def f(x):
    x_hat = tf.signal.rfft(x)
    x_recon = tf.signal.irfft(x_hat)
    lnp = tf.reduce_sum(x_recon)
    tf.print(lnp)
    return lnp


@tf.function
def ode_fn(x):
    return -x + 1.0


# Forward-Euler integration of ode_fn for nsteps steps; returns the whole trajectory.
@tf.function
def integrate(x, nsteps, time_step):
    y = tf.TensorArray(dtype=tf.float32, size=nsteps)
    x_next = x
    for i in tf.range(nsteps):
        x_next = x_next + time_step * ode_fn(x_next)
        y = y.write(i, x_next)
    return y.stack()


# For each row x[i], integrates the ODE, sums the trajectory into a scalar,
# and writes the gradient of that scalar with respect to x[i].
@tf.function
def compute_grads(x):
    lnpgrad = tf.TensorArray(dtype=tf.float32, size=x.shape[0])

    def cond(i, lnpgrad):
        return tf.less(i, x.shape[0])

    def body(i, lnpgrad):
        x_i = x[i]
        with tf.GradientTape() as tape:
            tape.watch(x_i)
            lnp_i = tf.reduce_sum(integrate(x_i, 20000, 0.1))
        tf.print(lnp_i)
        lnpgrad = lnpgrad.write(i, tape.gradient(lnp_i, x_i))
        return i + 1, lnpgrad

    i = tf.constant(0, dtype=tf.int32)
    i, lnpgrad = tf.while_loop(cond,
                               body,
                               loop_vars=[i, lnpgrad],
                               parallel_iterations=5)
    lnpgrad = lnpgrad.stack()
    return lnpgrad


x = tf.random.normal(shape=[10, 5000],
                     mean=10.0,
                     stddev=1.0,
                     dtype=tf.float32)

start = time.time()
fx_grads = compute_grads(x)
end = time.time()
print(f"Elapsed {end - start} seconds")

From my understanding, tf.GradientTape() computes gradients once, after which it releases all of its resources. To repeatedly compute gradients inside the while loop, consider setting the persistent argument to True in tf.GradientTape().
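
For illustration, a minimal sketch of that suggestion (x_i and the squared-sum function here are just stand-ins for the real computation):

import tensorflow as tf

x_i = tf.random.normal(shape=[5000])
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x_i)
    lnp_i = tf.reduce_sum(x_i * x_i)

# With persistent=True the same tape can be queried more than once.
g1 = tape.gradient(lnp_i, x_i)
g2 = tape.gradient(lnp_i, x_i)  # would raise an error without persistent=True
del tape  # release the tape's resources when done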

A persistent gradient tape is for computing gradients over the same computation multiple times, but that is not what I am doing. For each row i in x I perform some transformations and compute the gradient of the result with respect to x[i], and each x[i] uses a different gradient tape instance/context. So I don't think a persistent tape is the problem here.

Anyway, I did try a persistent tape and it made no difference in parallelizing the gradient computation.
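
For reference, a batched alternative I may look into instead (names are illustrative; this assumes each row's result depends only on that row, which holds for the element-wise ode_fn above, and that memory allows materializing the full trajectory) is a single tape over the whole batch, so one backward pass yields all per-row gradients:

@tf.function
def compute_grads_batched(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        # integrate() is element-wise here, so row i of the trajectory
        # depends only on x[i].
        lnp = tf.reduce_sum(integrate(x, 20000, 0.1))
    # Because the rows are independent, the gradient of the summed scalars
    # with respect to x is the stack of the per-row gradients.
    return tape.gradient(lnp, x)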