Converting a TensorFlow GraphDef to TOSA-only MLIR using Python APIs

Hi, I am new to TensorFlow MLIR and want to convert a model's GraphDef into TOSA MLIR using the experimental Python APIs. I was able to do so like this:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
input_shape = (20, 50, 100, 32)
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, 3, 2, input_shape=input_shape[1:], activation='relu'),
    tf.keras.layers.MaxPooling2D(3, 2)
])
model.build()
graph = tf.compat.v1.get_default_graph()
graph_def = graph.as_graph_def()

mlir = tf.mlir.experimental.convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline, func.func(tosa-legalize-tf)', show_debug_info=False)
```

The output looks like this:

```mlir
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1645 : i32}} {
  func.func @main() {
    %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
    %1 = "tosa.const"() <{value = dense<[898181614, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
    %2 = "tosa.const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
    %3 = "tosa.const"() <{value = dense<-0.102062076> : tensor<f32>}> : () -> tensor<f32>
    %4 = "tosa.const"() <{value = dense<[3, 3, 32, 32]> : tensor<4xi32>}> : () -> tensor<4xi32>
    %5 = "tosa.const"() <{value = dense<0.204124153> : tensor<f32>}> : () -> tensor<f32>
    %6 = "tf.VarHandleOp"() {_class = ["loc:@conv2d/bias"], allowed_devices = [], container = "", debug_name = "conv2d/bias/", device = "", shared_name = "conv2d/bias"} : () -> tensor<!tf_type.resource<tensor<32xf32>>>
    "tf.AssignVariableOp"(%6, %0) {device = "", validate_shape = false} : (tensor<!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
    %7 = "tf.VarHandleOp"() {_class = ["loc:@conv2d/kernel"], allowed_devices = [], container = "", debug_name = "conv2d/kernel/", device = "", shared_name = "conv2d/kernel"} : () -> tensor<!tf_type.resource<tensor<3x3x32x32xf32>>>
    %key, %counter = "tf.StatelessRandomGetKeyCounter"(%1) {_class = ["loc:@conv2d/kernel"], device = ""} : (tensor<2xi32>) -> (tensor<1xui64>, tensor<2xui64>)
    %8 = "tf.StatelessRandomUniformV2"(%4, %key, %counter, %2) {_class = ["loc:@conv2d/kernel"], device = ""} : (tensor<4xi32>, tensor<1xui64>, tensor<2xui64>, tensor<i32>) -> tensor<3x3x32x32xf32>
    %9 = tosa.mul %8, %5 {shift = 0 : i8} : (tensor<3x3x32x32xf32>, tensor<f32>) -> tensor<3x3x32x32xf32>
    %10 = tosa.add %9, %3 : (tensor<3x3x32x32xf32>, tensor<f32>) -> tensor<3x3x32x32xf32>
    "tf.AssignVariableOp"(%7, %10) {device = "", validate_shape = false} : (tensor<!tf_type.resource<tensor<3x3x32x32xf32>>>, tensor<3x3x32x32xf32>) -> ()
    return
  }
}
```

It contains a mix of TF and TOSA dialect ops. How can I convert it so that it uses only the TOSA dialect?

To convert a TensorFlow model to MLIR that uses only the TOSA dialect with the Python APIs, you can follow these steps:

  1. Prepare the Model: Disable eager execution and define your TensorFlow model using the tf.keras API.
  2. Extract GraphDef: Obtain the GraphDef from the TensorFlow computational graph.
  3. Convert to MLIR: Use the tf.mlir.experimental.convert_graph_def function with a conversion pipeline that includes tosa-legalize-tf to convert TensorFlow operations to TOSA operations.
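  4. Handle the Remaining TensorFlow Ops: tosa-legalize-tf only rewrites ops that have a TOSA lowering. The tf.* ops left in your output (tf.VarHandleOp, tf.AssignVariableOp, tf.StatelessRandomGetKeyCounter, tf.StatelessRandomUniformV2) are the variable-initialization ops for the Conv2D kernel and bias, and the pass did not legalize them. The Conv2D and MaxPooling2D computation itself does not appear in @main, which takes no arguments and returns nothing, so those ops were most likely dropped as dead code. To get a TOSA-only module, freeze the variables into constants and convert a function with explicit inputs and outputs instead of the whole default graph; ops without a TOSA lowering may still need additional legalization or cleanup passes.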

Here's a condensed version of the code for steps 1 to 3:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Define your TensorFlow model
input_shape = (20, 50, 100, 32)
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, 3, 2, input_shape=input_shape[1:], activation='relu'),
    tf.keras.layers.MaxPooling2D(3, 2)
])
model.build()

# Get the GraphDef
graph_def = tf.compat.v1.get_default_graph().as_graph_def()

# Convert to MLIR with TOSA conversion
mlir = tf.mlir.experimental.convert_graph_def(
    graph_def,
    pass_pipeline='tf-standard-pipeline, func.func(tosa-legalize-tf)',
    show_debug_info=False
)

print(mlir)
```
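To check whether a conversion really produced TOSA-only MLIR, it can help to list the ops that belong to other dialects. The helper below is a hypothetical sketch (`non_tosa_ops` is not part of TensorFlow or MLIR): it scans the printed module with a regular expression instead of a real MLIR parser, and by default treats the `tosa` and `func` dialects as allowed.

```python
import re
from collections import Counter
from typing import Iterable

# Matches the op name at the start of an MLIR operation line, e.g.
#   %9 = tosa.mul %8, %5 ...                              (pretty form)
#   "tf.AssignVariableOp"(%6, %0) ...                     (generic form)
#   %key, %counter = "tf.StatelessRandomGetKeyCounter"(%1) ...
_OP_RE = re.compile(
    r'^\s*(?:%[^=\n]*=\s*)?"?([A-Za-z_][\w$]*\.[A-Za-z_][\w.$]*)"?',
    re.MULTILINE,
)


def non_tosa_ops(mlir_text: str, allowed: Iterable[str] = ("tosa", "func")) -> Counter:
    """Count ops whose dialect prefix is not in `allowed`.

    Text-level heuristic only: it looks at the first dotted name on each
    line, so ops printed across several lines or sharing a line can be missed.
    """
    allowed_dialects = set(allowed)
    counts = Counter()
    for match in _OP_RE.finditer(mlir_text):
        op_name = match.group(1)
        dialect = op_name.split(".", 1)[0]  # "tf.VarHandleOp" -> "tf"
        if dialect not in allowed_dialects:
            counts[op_name] += 1
    return counts
```

Calling `non_tosa_ops(mlir)` on the output shown in the question would flag the `tf.VarHandleOp`, `tf.AssignVariableOp`, `tf.StatelessRandomGetKeyCounter` and `tf.StatelessRandomUniformV2` ops; an empty result means no ops outside the allowed dialects were found.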

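On its own, this pipeline produces the same mixed TF/TOSA output shown in the question, so a complete conversion to the TOSA dialect, without any remaining TensorFlow operations, needs the further adjustments from step 4 and possibly some exploration of additional MLIR passes.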
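One possible adjustment, shown here as an untested sketch rather than an official recipe: trace the model's forward pass as a `tf.function` with a fixed input signature, freeze its variables into constants with `convert_variables_to_constants_v2`, and pass the frozen function to `tf.mlir.experimental.convert_function` instead of converting the whole default graph. Assumptions: `convert_variables_to_constants_v2` lives in TensorFlow's internal `tensorflow.python.framework.convert_to_constants` module (not a public API, so its location may change), eager execution is left enabled because this path uses TF2 concrete functions, the input is declared with `tf.keras.Input`, and the batch size of 20 comes from the question's `input_shape`.

```python
import tensorflow as tf
# Internal (non-public) TensorFlow module; assumed location, may change between releases.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Eager execution is NOT disabled here: tf.function tracing and
# convert_function operate on TF2 concrete functions.
input_shape = (20, 50, 100, 32)
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=input_shape[1:]),
    tf.keras.layers.Conv2D(32, 3, 2, activation='relu'),
    tf.keras.layers.MaxPooling2D(3, 2),
])

# Trace only the forward pass, with a fixed input signature, so the
# resulting function has explicit inputs and outputs.
@tf.function(input_signature=[tf.TensorSpec(input_shape, tf.float32)])
def forward(x):
    return model(x)

concrete = forward.get_concrete_function()

# Replace the Conv2D kernel/bias variables with constants, so no
# VarHandleOp / AssignVariableOp / random-initializer ops remain.
frozen = convert_variables_to_constants_v2(concrete)

mlir = tf.mlir.experimental.convert_function(
    frozen,
    pass_pipeline='tf-standard-pipeline,func.func(tosa-legalize-tf)',
    show_debug_info=False,
)
print(mlir)
```

If any `tf.*` ops still show up after this, they are most likely ops that `tosa-legalize-tf` has no lowering for in your TensorFlow version.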