How to set the device for a tensor?

I am curious about how to set the device for a tensor in TF. I want to implement a custom distributed data-parallel algorithm where, for example, I split an input tensor x into three parts and transfer each part to a different device.

So basically, I want to write something like:

x0, x1, x2 = tf.split(x, num_or_size_splits=3, axis=1)
x0 = x0.to('device:0')
x1 = x1.to('device:1')
x2 = x2.to('device:2')

But this seems quite impossible in TF.

I also found something about colocation_graph; should I use that?


You can do that using the with tf.device(...) context manager.
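
For example (a minimal sketch; the GPU device names below are assumptions, so substitute whatever tf.config.list_logical_devices() reports on your machine):

import tensorflow as tf

x = tf.random.normal([4, 6])
x0, x1, x2 = tf.split(x, num_or_size_splits=3, axis=1)

# In eager mode, tf.identity under a tf.device scope copies the
# tensor onto that device; the result's .device reflects the placement.
with tf.device('/GPU:0'):
    x0 = tf.identity(x0)
with tf.device('/GPU:1'):
    x1 = tf.identity(x1)
with tf.device('/GPU:2'):
    x2 = tf.identity(x2)

print(x0.device, x1.device, x2.device)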

Does it help?


Thanks for the reply, and sorry for the unclear question. The with context manager only works in Python, IMHO.

However, to implement data parallelism, I would have to rewrite one of TF’s default passes. In that case, how would I handle device placement in C++? Because as far as I know, TF’s tensor does not carry device information.


Hmm, I don’t know.

Is this for the training step?
I lack the background, but maybe the Distributed training with TensorFlow | TensorFlow Core guide can give some insights.
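
For instance, the built-in data-parallel strategy from that guide is used roughly like this (a minimal sketch; the one-layer model is just a placeholder, not your actual model):

import tensorflow as tf

# MirroredStrategy replicates the model across all visible GPUs
# (falling back to CPU if none) and splits each batch between them.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer='sgd', loss='mse')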


Are you looking to create your own custom distributed strategy?

Because I don’t think we officially support this:


Thanks for the reply! Yes, I am trying to create my own custom distributed strategy, but it seems that doing this in TF involves a lot of trouble…


You can try to look at:
