Is there a way to create multiple gather() outputs and stack them in parallel in a compute- and memory-efficient manner?

I’m trying to essentially create a 3-D tensor from the indexed rows of a 2-D tensor. For example, assuming I have:

A = tensor(shape=[200, 256]) # 2-D Tensor.
Aidx = tensor(shape=[1000, 10]) # 2-D Tensor holding row indices of A for each of 1000 batches.

I wish to create:

B = tensor(shape=[1000, 10, 256]) # 3-D Tensor with each batch being of dims (10, 256) selected from A.

Right now, I’m doing this in a memory-inefficient way: a tf.broadcast_to() followed by a tf.gather(). This is very fast, but it materializes a [1000, 200, 256] copy of A and so takes up a lot of RAM:

import tensorflow as tf
A_tiled = tf.broadcast_to(A, [1000, A.shape[0], A.shape[1]])  # Shape [1000, 200, 256].
B = tf.gather(A_tiled, Aidx, axis=1, batch_dims=1)  # Shape [1000, 10, 256].

Is there a more memory-efficient way of doing the above operation? Naively, one could use a for loop over the 1000 batches, but that is too compute-inefficient for my use case. Thanks in advance!
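For comparison, here is a minimal sketch of one candidate that skips the broadcast. tf.gather with its default axis=0 accepts a multi-dimensional indices tensor and returns a result of shape indices.shape + params.shape[1:]. The random values and the names A_tiled, B_broadcast and B_direct are placeholders chosen for illustration, and the actual memory savings are not measured here:

```python
import tensorflow as tf

# Toy stand-ins with the shapes from the question (random values, for illustration only).
A = tf.random.uniform([200, 256])
Aidx = tf.random.uniform([1000, 10], minval=0, maxval=200, dtype=tf.int32)

# Approach from the question: broadcast A to [1000, 200, 256], then a batched gather.
A_tiled = tf.broadcast_to(A, [1000, A.shape[0], A.shape[1]])
B_broadcast = tf.gather(A_tiled, Aidx, axis=1, batch_dims=1)

# Candidate alternative: gather rows of A directly with the 2-D index tensor.
# With axis=0 the result has shape Aidx.shape + A.shape[1:], i.e. [1000, 10, 256].
B_direct = tf.gather(A, Aidx)

# Both results should hold exactly the same rows of A.
same = bool(tf.reduce_all(tf.equal(B_broadcast, B_direct)))
```

If the two results agree, as the check in `same` is meant to verify, the direct gather would avoid ever creating the [1000, 200, 256] intermediate.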