I’m trying to build a 3-D tensor by indexing rows of a 2-D tensor. For example, assuming I have:

```
import tensorflow as tf
A = tf.random.normal([200, 256])  # 2-D tensor.
Aidx = tf.random.uniform([1000, 10], maxval=200, dtype=tf.int32)  # 2-D tensor holding row indices of A for each of 1000 batches.
```

I wish to create:

```
B = tensor(shape=[1000, 10, 256]) # 3-D Tensor with each batch being of dims (10, 256) selected from A.
```

Right now, I’m doing this in a memory-inefficient manner by calling `tf.broadcast_to()` and then `tf.gather()`. This is very fast, but it also takes up a lot of RAM, because `A` is expanded to shape `[1000, 200, 256]` before the gather:

```
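# Expand A to [1000, 200, 256] so every batch has its own copy to index into.
A_tiled = tf.broadcast_to(A, [1000, A.shape[0], A.shape[1]])
# Batched gather along axis 1: result has shape [1000, 10, 256].
B = tf.gather(A_tiled, Aidx, axis=1, batch_dims=1)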
```

Is there a more memory-efficient way to do the above operation? Naively, one could use a for loop over the batches, but that is too slow for my use case. Thanks in advance!
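
For reference, here is a minimal sketch of the naive loop I mean, checked against the broadcast-and-gather version. The sizes are shrunk stand-ins and the names `B_loop`, `B_fast`, and `A_tiled` are just illustrative, so treat this as a description of the result I want rather than real code from my project:

```python
import tensorflow as tf

# Small stand-in sizes (assumed for illustration; the real ones are 200, 256, 1000, 10).
num_rows, dim, num_batches, k = 20, 8, 5, 3

A = tf.random.normal([num_rows, dim])
Aidx = tf.random.uniform([num_batches, k], maxval=num_rows, dtype=tf.int32)

# Naive version: gather the k rows for each batch separately, then stack the results.
B_loop = tf.stack([tf.gather(A, Aidx[i]) for i in range(num_batches)], axis=0)

# Current version: broadcast A across batches, then do a batched gather along axis 1.
A_tiled = tf.broadcast_to(A, [num_batches, num_rows, dim])
B_fast = tf.gather(A_tiled, Aidx, axis=1, batch_dims=1)

print(B_loop.shape, B_fast.shape)
```

Both versions should produce the same `[num_batches, k, dim]` tensor; the loop avoids the large intermediate but runs one gather per batch.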