How to enable GSPMD?


We are trying to use GSPMD on a server with 8 GPUs, via the API ‘xla_sharding.mesh_split’.

To dump the IRs, we set the following environment variables:
TF_DUMP_GRAPH_PREFIX=/tmp/generated
TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2"
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated"
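The same variables can also be set from Python, as a minimal sketch. It is generally safest to set them before TensorFlow is imported, since TF and XLA parse these flags once and later changes may be ignored. The extra `--xla_dump_hlo_pass_re=.*` is an addition not in the original setup: it asks XLA to dump the HLO after every pass, which would make it visible whether any SPMD-related pass ran at all.

```python
import os

# Dump directory used in the post.
DUMP_DIR = "/tmp/generated"

# Set before importing TensorFlow; flags are written with plain ASCII "--".
os.environ["TF_DUMP_GRAPH_PREFIX"] = DUMP_DIR
os.environ["TF_XLA_FLAGS"] = "--tf_xla_clustering_debug --tf_xla_auto_jit=2"
# Added assumption: --xla_dump_hlo_pass_re=.* also dumps HLO after each XLA pass.
os.environ["XLA_FLAGS"] = (
    f"--xla_dump_hlo_as_text --xla_dump_to={DUMP_DIR} --xla_dump_hlo_pass_re=.*"
)

# import tensorflow as tf  # import TensorFlow only after this point
```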

But only 4 IR dumps are saved: mark_for_compilation.pbtxt, mark_for_compilation_annotated.pbtxt, before_mark_for_compilation.pbtxt and before_increase_dynamism_for_auto_jit_pass.pbtxt. None of them is related to the SPMD pass.
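One way to read such a dump directory is to separate TensorFlow graph dumps (the `.pbtxt` files listed above) from XLA HLO dumps. The helpers below are hypothetical, and they assume XLA's text dumps follow the usual `module_*.txt` naming under `--xla_dump_to`. If no HLO files appear at all, that would suggest XLA never compiled a cluster, in which case no XLA pass (SPMD or otherwise) could have run.

```python
import os


def classify_dump_files(names):
    """Hypothetical helper: split file names into TF graph dumps and XLA HLO dumps.

    Assumption: TF graph dumps end in ".pbtxt" and XLA text dumps are
    named "module_*.txt" (the usual naming when --xla_dump_to is set).
    """
    graphs = sorted(n for n in names if n.endswith(".pbtxt"))
    hlo = sorted(n for n in names if n.startswith("module_") and n.endswith(".txt"))
    return graphs, hlo


def summarize_dump_dir(dump_dir="/tmp/generated"):
    """Print how many graph and HLO dumps a directory contains."""
    graphs, hlo = classify_dump_files(os.listdir(dump_dir))
    print(f"{len(graphs)} TF graph dump(s), {len(hlo)} XLA HLO dump(s)")
    if not hlo:
        # Without HLO modules there is no evidence that XLA compiled anything.
        print("No HLO modules found: XLA does not appear to have compiled a cluster.")
    return graphs, hlo
```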

It seems that our current run does not enable GSPMD at all. Is there a tutorial or a set of instructions we can follow to enable GSPMD on multiple GPUs?

Our test code is below, copied and modified from the tensorflow/tensorflow repository on GitHub (tensorflow/ at r2.7):

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
import numpy as np
from tensorflow.python.eager import def_function


class XlaShardingTest(test_util.TensorFlowTestCase):

  def test_dot_split(self):

    # The sharding is recorded as an attribute on the op, so split_helper
    # has to be traced as a function rather than run eagerly.
    @def_function.function
    def split_helper(tensor):
      # 2x4 mesh over 8 devices: tensor dim 0 -> mesh axis 0, dim 1 -> mesh axis 1.
      device_mesh = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])
      split_tensor = xla_sharding.mesh_split(tensor, device_mesh, [0, 1])
      self.assertIsInstance(split_tensor, ops.Tensor)
      split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
      split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
      expected_shape = [2, 4]
      self.assertEqual(expected_shape, split_shape)

      y_tensor = array_ops.ones([8, 8], dtype=dtypes.float32)
      y_split = xla_sharding.mesh_split(y_tensor, device_mesh, [0, 1])
      result = math_ops.matmul(split_tensor, y_split)
      # Re-shard the result over a 4x2 mesh, then swap the mapping after sqrt.
      device_mesh = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
      result = xla_sharding.mesh_split(result, device_mesh, [0, 1])
      result = math_ops.sqrt(result)
      result = xla_sharding.mesh_split(result, device_mesh, [1, 0])
      return result

    # ones(8x8) @ ones(8x8) is 8 everywhere, and sqrt(8) = 2 * sqrt(2).
    in_tensor = 2 * np.sqrt(2) * array_ops.ones([8, 8], dtype=dtypes.float32)
    result = split_helper(array_ops.ones([8, 8], dtype=dtypes.float32))
    self.assertAllEqual(in_tensor, result)


if __name__ == "__main__":
  # Run the test case (constructing XlaShardingTest() alone does not run it).
  test.main()
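To see what the `mesh_split` calls above describe, the following NumPy-only sketch simulates which block of an 8x8 tensor each device would hold. `mesh_split_tiles` is a hypothetical helper; it does not call TensorFlow or XLA. It assumes that the i-th entry of the mapping names the mesh axis that splits tensor dimension i, and it only handles the case used in the test, where the mapping is a full permutation of the mesh axes (no replicated dimensions).

```python
import numpy as np


def mesh_split_tiles(tensor, device_mesh, dims_mapping):
    """Hypothetical helper: return (tile_shape, {device_id: shard}).

    dims_mapping[i] is the mesh axis that splits tensor dimension i; this
    sketch requires it to be a permutation of all mesh axes.
    """
    tensor = np.asarray(tensor)
    device_mesh = np.asarray(device_mesh)
    if tensor.ndim != device_mesh.ndim or sorted(dims_mapping) != list(range(device_mesh.ndim)):
        raise ValueError("this sketch only handles a full permutation of mesh axes")
    # Axis i of the tile assignment is mesh axis dims_mapping[i].
    tile_assignment = np.transpose(device_mesh, dims_mapping)
    if any(n % t for n, t in zip(tensor.shape, tile_assignment.shape)):
        raise ValueError("each tensor dim must be divisible by its number of tiles")
    block = [n // t for n, t in zip(tensor.shape, tile_assignment.shape)]
    shards = {}
    for idx in np.ndindex(*tile_assignment.shape):
        # The device at position idx of the tile assignment holds this block.
        region = tuple(slice(i * b, (i + 1) * b) for i, b in zip(idx, block))
        shards[int(tile_assignment[idx])] = tensor[region]
    return list(tile_assignment.shape), shards
```

With the test's first mesh `[[0, 1, 2, 3], [4, 5, 6, 7]]` and mapping `[0, 1]`, the tile shape comes out as `[2, 4]`, matching `expected_shape`, and each of the 8 devices holds a 4x2 block. On the 4x2 mesh with mapping `[1, 0]`, the mesh is transposed first, so the tile shape is again `[2, 4]` but with a different device order.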

Hi @Xiaoda_Zhang, the test code works fine as per the gist, but I am not sure about enabling GSPMD. Could you go through the contents of the files generated by the code execution and share any useful information in them?
Could you also confirm details such as how you enabled the GPUs and which distribution strategy you followed, so that we can dig into the problem?