Which MLIR dialects can support both training and inference with TensorFlow?

Hi, TensorFlow provides several options for translating a TF graph to MLIR dialects such as stablehlo, hlo, and mlprogram. Which of these dialects support training? Also, is it possible to convert the graph_def to just one of these MLIR dialects? I used the TF MLIR experimental APIs to convert to these dialects like this:

import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2D
# Run in graph mode so the ops land in the v1 default graph.
tf.compat.v1.disable_eager_execution()

model = Sequential()
model.add(Conv2D(filters=1, input_shape=(32, 32, 3), kernel_size=(2, 2), strides=(2, 2), padding='valid', use_bias=False))

model.compile(optimizer='adam', loss='mse', metrics=['accuracy'], run_eagerly=False)

# Grab the default graph that building and compiling the model populated.
graph = tf.compat.v1.get_default_graph()
graph_def = graph.as_graph_def()

mlir = tf.mlir.experimental.convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline, tf-stablehlo, stablehlo-legalize-to-hlo, hlo-legalize-to-stablehlo', show_debug_info=False)

print(mlir)

Some of the pipelines produced MLIR containing a mix of dialects. Which dialects support training and can represent the whole graph in a single dialect?

TensorFlow’s integration with MLIR (Multi-Level Intermediate Representation) offers several dialects for optimizing and executing machine learning models. These dialects serve different purposes, from representing high-level TensorFlow operations down to operations closer to the hardware. Regarding your question about which MLIR dialects support both training and inference in TensorFlow, here’s a breakdown:

MLIR Dialects for Training and Inference:

  1. TensorFlow Dialect (tf dialect): This is a direct representation of TensorFlow graphs in MLIR. It is suitable for both training and inference because it closely mirrors TensorFlow’s own operations (a minimal sketch follows this list).
  2. TensorFlow HLO Dialect (mhlo): HLO (High-Level Optimizer) is the intermediate representation used by the XLA compiler. The mhlo dialect represents these operations in MLIR and supports both training and inference. It is lower-level than the tf dialect and is used for optimizations closer to the hardware.
  3. StableHLO Dialect: StableHLO provides a stable, versioned set of HLO-like operations. It aims to cover a broad range of machine learning operations, including those needed for training.
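
A minimal sketch of point 1, reusing the graph_def from your snippet: running only tf-standard-pipeline (the API’s default) should leave the module in the tf dialect, plus structural ops such as func.func.

import tensorflow as tf

# Hedged sketch: convert the graph_def built in the question with only
# the standard pipeline; the result should stay in the tf dialect.
mlir_tf = tf.mlir.experimental.convert_graph_def(
    graph_def,  # the graph_def from the snippet above
    pass_pipeline='tf-standard-pipeline',
    show_debug_info=False)
print(mlir_tf)  # ops print as tf.* (e.g. "tf.Conv2D")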

Conversion to a Single Dialect:

  • Whether a TensorFlow graph can be converted into a single MLIR dialect depends on the operations in the model and on the target dialect’s ability to represent them.
  • For example, converting a TensorFlow graph entirely into the tf dialect is usually straightforward, because that dialect is designed to represent TensorFlow operations directly. Converting a complex model entirely into a lower-level dialect like mhlo or StableHLO can be harder, since not every TF op has a legalization; a sketch of a two-step lowering follows this list.
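
Assuming your TensorFlow build exposes tf.mlir.experimental.run_pass_pipeline (part of the same experimental API family), here is a hedged sketch of that two-step lowering. The tf-stablehlo pass name is taken from your pipeline string; whether every op legalizes depends on the model and the build.

import tensorflow as tf

# Step 1: canonicalize the graph into the tf dialect.
mlir_tf = tf.mlir.experimental.convert_graph_def(
    graph_def, pass_pipeline='tf-standard-pipeline', show_debug_info=False)

# Step 2: run the legalization passes on the textual module. Any op the
# pipeline cannot legalize stays as a tf.* op, giving mixed output.
mlir_stablehlo = tf.mlir.experimental.run_pass_pipeline(
    mlir_tf, pass_pipeline='tf-stablehlo', show_debug_info=False)
print(mlir_stablehlo)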

Experimental APIs and Mixed Dialects:

  • The TensorFlow MLIR experimental APIs you used can indeed produce mixed-dialect MLIR, because different parts of the TensorFlow graph may be best represented, or may only be legalizable, in different dialects.
  • Which dialect a given part of the graph ends up in depends on the passes you run and on their optimization and execution requirements: some operations legalize cleanly to mhlo or stablehlo, while others remain most naturally expressed in the tf dialect. A quick way to inspect what a pipeline actually produced is shown after this list.
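
One quick, purely textual way to check which dialects a pipeline produced is to collect the dialect prefixes of the printed op names. This is a heuristic over the string output, not an MLIR API:

import re

def dialects_used(mlir_text):
    # MLIR op names render as `dialect.op` (e.g. "tf.Conv2D",
    # stablehlo.convolution), so the prefix before the first dot names
    # the dialect. String inspection only; expect a little noise from
    # dialect attributes such as #tf_type.
    return sorted(set(re.findall(r'\b([a-z][a-z0-9_]*)\.[A-Za-z]', mlir_text)))

# Reusing the `mlir` string from your snippet:
print(dialects_used(mlir))  # e.g. ['func', 'tf'] for tf-standard-pipeline output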

Conclusion:

  • Both the TensorFlow (tf) and the TensorFlow HLO (mhlo) dialects can support training and inference. StableHLO likewise aims to cover a wide range of machine learning operations, including those required for training.
  • Converting a TensorFlow graph to a single MLIR dialect is possible, but depends on whether the target dialect can represent every operation in the model. The tf dialect is generally a safe choice for a direct representation; if you are aiming for lower-level optimization, mhlo or StableHLO may be more appropriate, though you may get mixed-dialect output depending on the model’s complexity. A sketch for checking training coverage concretely follows below.
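
If you want to check training specifically, a hedged sketch: trace one gradient step as a concrete function and convert it, so the emitted MLIR includes the backward-pass ops. tf.mlir.experimental.convert_function is the documented counterpart of convert_graph_def for concrete functions; how far the result lowers again depends on the pass pipeline and your TensorFlow build. Run this in eager mode (without disable_eager_execution).

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(1, (2, 2), strides=(2, 2), use_bias=False),
])
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(x, y):
    # One MSE gradient step, so the traced graph contains the
    # gradient and optimizer-update ops, not just the forward pass.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

concrete = train_step.get_concrete_function(
    tf.TensorSpec([1, 32, 32, 3], tf.float32),
    tf.TensorSpec([1, 16, 16, 1], tf.float32))
mlir_train = tf.mlir.experimental.convert_function(
    concrete, pass_pipeline='tf-standard-pipeline', show_debug_info=False)
print(dialects_used(mlir_train))  # helper from above; expect tf.* training ops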

Thank you, Tim. So the lowering to a single dialect is highly dependent on the model architecture and the operations used.