Visualize TensorFlow graphs before and after Grappler passes?

I’ve been trying to visualize the graph of a tf.function with and without Grappler optimizations, but so far I can’t see any difference in the generated graphs.

Here is the process I followed: I took the code from the “TensorFlow graph optimization with Grappler” tutorial to build the graph with and without constant folding. Then I used the experimental_get_compiler_ir function with the “hlo” setting to dump the graph before XLA optimizations: simple_function.experimental_get_compiler_ir(x)(stage="hlo"). Finally, I used the interactive_graphviz XLA tool to visualize the graphs. Unfortunately, I can’t see any difference between the generated graphs: in both cases the constants have not been folded.

When designing this process, I assumed that Grappler optimizations run first and XLA optimizations then take over. However, I just found a recent post suggesting that using XLA could disable Grappler passes, which doesn’t make sense to me but is consistent with my observation.

Is there anything wrong in the process I’m following? If so, is there an alternative way to visualize graphs before and after Grappler passes?

Here is the code I used to generate the graphs:

import numpy as np
import timeit
import traceback
import contextlib

import tensorflow as tf

@contextlib.contextmanager
def options(options):
  old_opts = tf.config.optimizer.get_experimental_options()
  tf.config.optimizer.set_experimental_options(options)
  try:
    yield
  finally:
    tf.config.optimizer.set_experimental_options(old_opts)

def test_function_1():
    @tf.function(jit_compile=True)
    def simple_function(input_arg):
        print('Tracing!')
        a = tf.constant(np.random.randn(2000,2000), dtype = tf.float32)
        c = a
        for n in range(50):
            c = c@a
        return tf.reduce_mean(c+input_arg)
    return simple_function

with tf.device("/gpu:0"):
    with options({'constant_folding': False}):
        print(tf.config.optimizer.get_experimental_options())
        simple_function = test_function_1()
        # Trace once
        x = tf.constant(2.2)
        simple_function(x)
        for i in range(2):
            print("Vanilla execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")
        # with open("no_constant_folding_opt_hlo.pbtxt", "w") as f:
        #     f.write(str(simple_function.experimental_get_compiler_ir(x)(stage="hlo")))

    with options({'constant_folding': True}):
        print(tf.config.optimizer.get_experimental_options())
        simple_function = test_function_1()
        # Trace once
        x = tf.constant(2.2)
        simple_function(x)
        for i in range(2):
            print("Constant folded execution:", timeit.timeit(lambda: simple_function(x), number = 1), "s")
        # with open("with_constant_folding_opt_hlo.pbtxt", "w") as f:
        #     f.write(str(simple_function.experimental_get_compiler_ir(x)(stage="hlo")))

Thanks for posting the snippet!

Where are you visualizing the pbtxt file? TensorBoard or Graphviz? Could you shed a bit more light on that?

I was trying to visualize the output using the interactive_graphviz command from XLA tools (path of source is tensorflow/tensorflow/compiler/xla/tools/interactive_graphviz.cc).
However, the two graphs look exactly the same, which is consistent with the fact that the two .hlotxt files that my code outputs are the same (see screencapture).
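Before opening the dumps in a viewer, a plain textual diff is a quick way to confirm that they really are identical. This is a minimal sketch using only the standard library; the helper name dump_diff and the short HLO-like strings are hypothetical stand-ins for the contents of the two real dump files.

```python
import difflib


def dump_diff(text_a, text_b, name_a="no_constant_folding", name_b="with_constant_folding"):
    """Return unified-diff lines between two IR text dumps (empty list if identical)."""
    return list(
        difflib.unified_diff(
            text_a.splitlines(),
            text_b.splitlines(),
            fromfile=name_a,
            tofile=name_b,
            lineterm="",
        )
    )


# Hypothetical stand-ins for the two dumps; in practice, read the files
# written by experimental_get_compiler_ir(...)(stage="hlo") instead.
hlo_a = "HloModule m\nENTRY main {\n  p0 = f32[] parameter(0)\n}"
hlo_b = hlo_a

identical = dump_diff(hlo_a, hlo_b)
changed = dump_diff(hlo_a, hlo_a.replace("p0", "p1"))

print("identical dumps:", not identical)
print("\n".join(changed))
```

An empty diff means the setting being toggled had no effect on the text that was dumped, which is a cheaper check than comparing two rendered graphs by eye.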

I am thinking either:

  • something is wrong with my simple_function and it cannot exploit Grappler’s constant folding optimization
  • or I do not use the right function to output the HLO IR: .experimental_get_compiler_ir(x)(stage="hlo").

My goal is to visualize graphs before and after Grappler passes. Do you know a better way to do this?

PS: I failed to include the version in the original post, but I am using TensorFlow 2.8.3.

Maybe use another stage, such as "optimized_hlo_dot"?

I was trying to visualize the output using the interactive_graphviz command from XLA tools

How does one use it?

My understanding is that "optimized_hlo_dot" outputs the HLO graph after the XLA optimization passes, which are different from Grappler passes.

However, I just found this post on the XLA Development Google group. It suggests that Grappler is not actually used when XLA is enabled (jit_compile=True). That would explain why the two HLO outputs from my first post are identical: Grappler is not running.
Considering this, I should try to do this same experiment without using XLA if I want to be able to visualize the effect of Grappler passes.
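For reference, here is a rough sketch of what that non-XLA experiment could look like. It rests on several assumptions: tf_optimizer.OptimizeGraph and export_meta_graph live in internal TensorFlow modules (not public API, so they may change between versions), and as far as I can tell the pattern mirrors how TensorFlow's own graph-freezing code invokes Grappler. The helper names grappler_optimize and op_counts are mine, and I use a tiny 2x2 constant instead of the 2000x2000 one to keep the graph small.

```python
from collections import Counter

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2, meta_graph_pb2, rewriter_config_pb2
# Internal modules (assumption: available in your TF version, not public API).
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph


@tf.function  # no jit_compile=True, so XLA is not involved
def simple_function(input_arg):
    a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    c = a
    for _ in range(5):
        c = c @ a
    return tf.reduce_mean(c + input_arg)


def grappler_optimize(concrete_fn, constant_folding=True):
    """Run Grappler on a concrete function's graph and return the new GraphDef."""
    graph = concrete_fn.graph
    meta_graph = export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)
    # Grappler keeps the nodes listed in the "train_op" collection.
    fetches = meta_graph_pb2.CollectionDef()
    for tensor in concrete_fn.inputs + concrete_fn.outputs:
        fetches.node_list.value.append(tensor.name)
    meta_graph.collection_def["train_op"].CopyFrom(fetches)

    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.min_graph_nodes = -1  # do not skip small graphs
    if not constant_folding:
        rewrite_options.constant_folding = rewriter_config_pb2.RewriterConfig.OFF
    return tf_optimizer.OptimizeGraph(config, meta_graph)


def op_counts(graph_def):
    """Count nodes per op type in a GraphDef."""
    return Counter(node.op for node in graph_def.node)


concrete = simple_function.get_concrete_function(tf.TensorSpec([], tf.float32))
before = concrete.graph.as_graph_def()  # graph as traced, before Grappler
after = grappler_optimize(concrete, constant_folding=True)

print("before:", dict(op_counts(before)))
print("after: ", dict(op_counts(after)))
```

Comparing the op counts (for example the number of MatMul nodes) gives a first indication of whether folding happened. Both GraphDefs can also be written out with tf.io.write_graph and opened in a graph viewer.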

On a side note, here are the usage notes in the TF source code for interactive_graphviz, if you are interested: