I am wondering if there is a way to get the HLO of an ML model defined in TF without having to instantiate the model’s weights. I know several ways to get the HLO of a model:
- setting the environment variables `XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2"` before running your TF model
- using tf.function’s
- using the SavedModel flow described in Using the SavedModel format | TensorFlow Core (which uses tf.function’s `ConcreteFunction` under the hood)
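For the environment-variable route, the flags generally need to be in place before TensorFlow initializes, so a separate process is the safest place to set them. A minimal sketch, assuming a hypothetical training script `train.py`; the helper name `xla_dump_env` is also hypothetical:

```python
import os


def xla_dump_env(dump_dir="/tmp/generated", base=None):
    """Return a copy of an environment with the XLA dump flags from this post set.

    `base` defaults to the current process environment.
    """
    env = dict(os.environ if base is None else base)
    # Append to any flags already present instead of overwriting them.
    xla = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = f"{xla} --xla_dump_to={dump_dir}".strip()
    tf_xla = env.get("TF_XLA_FLAGS", "")
    env["TF_XLA_FLAGS"] = f"{tf_xla} --tf_xla_auto_jit=2".strip()
    return env
```

Usage would look something like `subprocess.run([sys.executable, "train.py"], env=xla_dump_env(), check=True)`, where `train.py` stands in for whatever script builds and runs the model.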
Unfortunately, these approaches require weight initialization, either by explicitly calling the Keras `build` function (in the XLA_FLAGS case), by running the model on some inputs (again, in the XLA_FLAGS case), or by implicitly or explicitly calling tf.function’s `get_concrete_function` method. This initializes the weights on the target device, which causes an OOM error when the model’s total parameter size in bytes exceeds the main memory available to the device.
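Whether a model hits that OOM is simple arithmetic on its weight shapes. A rough sketch, assuming dense weights and a hypothetical dtype-size table; optimizer state and activations would add to this, so the result is only a lower bound:

```python
from math import prod

# Bytes per element for a few common dtypes (illustrative table, extend as needed).
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def param_bytes(shapes, dtype="float32"):
    """Bytes needed just to hold weights with the given shapes."""
    return sum(prod(shape) for shape in shapes) * DTYPE_BYTES[dtype]


def fits_on_device(shapes, device_bytes, dtype="float32"):
    """True if the raw weights alone fit in `device_bytes` of device memory."""
    return param_bytes(shapes, dtype) <= device_bytes


# Hypothetical example: one large float32 embedding table vs. an 80 GiB device.
big_model = [(50_000_000, 1024)]
print(param_bytes(big_model))  # 204_800_000_000 bytes, about 204.8 GB
print(fits_on_device(big_model, 80 * 2**30))
```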
Is there a way to get the HLO graph of a TF model without instantiating the weights? In the HLO graphs I’ve visualized before, weights are always represented by Parameter nodes and are passed as arguments to the train step, so I’m unsure why the weights themselves must be initialized before I can convert to an `HloModule`. Does anyone have any insight?
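To make the Parameter-node reasoning concrete, here is a minimal sketch, assuming TensorFlow 2.x and a hypothetical two-layer model written as a plain function whose weights are explicit arguments rather than Keras variables. Tracing it with `get_concrete_function` on `tf.TensorSpec`s records only shapes and dtypes, so no weight memory is allocated, and each weight appears as a `Placeholder` in the traced graph, analogous to an HLO Parameter. This sketch stops at the TF graph and does not show the lowering to HLO.

```python
import tensorflow as tf

# Hypothetical sizes; the weights are described by specs and never allocated.
IN_DIM, HIDDEN, OUT_DIM, BATCH = 1024, 4096, 10, 32

weight_specs = (
    tf.TensorSpec([IN_DIM, HIDDEN], tf.float32, name="w1"),
    tf.TensorSpec([HIDDEN], tf.float32, name="b1"),
    tf.TensorSpec([HIDDEN, OUT_DIM], tf.float32, name="w2"),
    tf.TensorSpec([OUT_DIM], tf.float32, name="b2"),
)
x_spec = tf.TensorSpec([BATCH, IN_DIM], tf.float32, name="x")


# jit_compile=True marks the function for XLA; tracing below does not compile it.
@tf.function(jit_compile=True)
def forward(x, w1, b1, w2, b2):
    h = tf.nn.relu(tf.matmul(x, w1) + b1)
    return tf.matmul(h, w2) + b2


# Symbolic trace: inputs and weights become graph placeholders, not tensors.
concrete = forward.get_concrete_function(x_spec, *weight_specs)
graph_def = concrete.graph.as_graph_def()
placeholders = [n.name for n in graph_def.node if n.op == "Placeholder"]
print(placeholders)
print(concrete.structured_outputs.shape)
```

The design choice here is to keep the weights out of `tf.Variable`s entirely, which is what lets tracing proceed from specs alone; a Keras model creates its variables in `build`, which is the step that allocates memory.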