TensorFlow transform passes hardcoded to the "main" function

Hello,

I notice that at least half a dozen TensorFlow transform passes (like tensor-list-ops-decomposition) are hardcoded to work only on a single function named “main”.
(tensor_list_ops_decomposition.cc on github)

  void TensorListOpsDecompositionPass::runOnOperation() {
    auto module = getOperation();
    auto main = module.lookupSymbol<FuncOp>("main");
    if (!main) return;
    if (failed(DecomposeTensorListOps(&main.front(), module))) {
      signalPassFailure();
    }
  }

Is there an assumption that the canonical form is one where the “entry function” is named “main”? This isn’t true for an import/translation from a tf.function, where the entry function has the tf.function’s name with a suffix/prefix. Should this check instead look for a function with the attribute “tf.entry_function”? And should this be patched like the diff below, or better, with a common utility used to update all passes with such checks?

-  auto main = module.lookupSymbol<FuncOp>("main");
-  if (!main) return;
-  if (failed(DecomposeTensorListOps(&main.front(), module))) {
-    signalPassFailure();
+  for (auto func_op : module.getOps<FuncOp>()) {
+    // Just run on the entry function.
+    if (!func_op->getAttr("tf.entry_function") && func_op.getName() != "main")
+      continue;
+    if (failed(DecomposeTensorListOps(&func_op.front(), module))) {
+      signalPassFailure();
+    }
+    break;
   }

Related to this, there are also several instances of “main” and “tf.entry_function” hardcoded in “transforms/” and “translate/”.

We likely should provide a helper for this instead of a raw loop.
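For illustration, a minimal sketch of what such a helper could look like (the name `GetEntryFunction` and its placement are hypothetical; this assumes the usual `ModuleOp`/`FuncOp` APIs):

```cpp
// Hypothetical helper (sketch): find the module's entry function, preferring
// a function carrying the "tf.entry_function" attribute and falling back to
// the "main" naming convention. Returns a null FuncOp if neither is present.
FuncOp GetEntryFunction(ModuleOp module) {
  for (auto func_op : module.getOps<FuncOp>()) {
    if (func_op->hasAttr("tf.entry_function")) return func_op;
  }
  return module.lookupSymbol<FuncOp>("main");
}
```

Each pass could then reduce its preamble to `auto main = GetEntryFunction(getOperation()); if (!main) return;` instead of re-implementing the lookup.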
Also, it isn’t clear to me what a function that is public (from an MLIR symbol visibility point of view) but isn’t an entry function would mean. If there is no such thing, why not just filter on public functions?

It’s possible that passes generated additional functions but missed marking them as private. If the canonical form is one with a single entry function marked in a defined way, these passes could run on that one. If not, they could run on all functions that are visible. Another alternative, which I feel is the least surprising, is to run them on everything. The current behavior isn’t really correct or in line with any of the standard forms being used.

@Mehdi_AMINI Anything further on the course of action here?

We haven’t invested further into figuring anything out on this right now.

The private/public and entry-function concepts are a bit confusing here. The function called main is the function corresponding to the Graph (so during conversion one has 1 Graph together with a function library containing multiple functions). For an execution of a model (which is what these passes were developed for and where they run), one therefore has this situation. The place where we have the clearest indication of public or private is during SavedModel conversion. But even there, in the supported workflows (the TFLite converter and TFRT serving conversion), we have a single entry point AFAIK. Could you show the Python code corresponding to the tf.function example?

If you just take the simplest example, like this one,

@tf.function(
    input_signature=(
        tf.TensorSpec(shape=(M, K), dtype=tf.float32),
        tf.TensorSpec(shape=(K, N), dtype=tf.float32),
    )
)
def matmul(lhs, rhs):
    return tf.matmul(lhs, rhs)

and use tensorflow.python.pywrap_mlir to do a:

 import_function(
        func.get_concrete_function(), pass_pipeline="", show_debug_info=False)

the MLIR you get won’t have a “main” but just something like matmul in its name. Most of the TensorFlow MLIR transforms would just end up being no-ops on it.

That is an experimental/testing API that is not on any execution or conversion path. It shows a part of the import (how a function in the flib would be imported) and can be used for “visualization”. There are multiple ways to indicate the entry point; the convention we follow is to name it main, in particular because during execution we have a nameless Graph at the point where we import.

Oh, can I know what the recommended API to import a decorated tf.function into MLIR would then be? I think your note confirms that the canonical and expected form is one where the entry function is named “main” – that was my original question.

We could change this import to generate a main function instead?

But in general it’s not clear to me why we don’t use “public” for most passes / why do we filter on “main”?
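To make that alternative concrete, filtering on visibility instead of the name would look roughly like this (a sketch, assuming FuncOp implements SymbolOpInterface so isPublic() is available):

```cpp
// Sketch: run the decomposition on every public function rather than only
// on the one named "main". Whether this matches the intended semantics is
// exactly the open question above.
void TensorListOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  for (auto func_op : module.getOps<FuncOp>()) {
    if (!func_op.isPublic()) continue;  // skip private helper functions
    if (failed(DecomposeTensorListOps(&func_op.front(), module)))
      signalPassFailure();
  }
}
```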