MLIR optimization failure with DenseHashTable

Hi

We are using tf.lookup.experimental.DenseHashTable within a tf.map_fn to perform specialised sorting of items within rows of a RaggedTensor. When loading the SavedModel with TensorFlow Serving, we get the following MLIR failure message:

error: 'tfg.While' op body function argument #6 type 'tensor<!tf_type.resource<tensor<!tf_type.string>>>' is not compatible with corresponding operand type: 'tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>'
2024-05-15 06:42:15.209491: E external/org_tensorflow/tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph Optimizer failed:

The tfg.While is generated from the map_fn. Argument #6 is the DenseHashTable, which has tf.string keys and tf.int32 values.

The code:

import tensorflow as tf

@tf.function
def sort_rail_horizontally_hashmap_map_fn(sorted_list: tf.Tensor, rails_to_sort: tf.RaggedTensor):

    @tf.function
    def reorder_rail(rail):
        # Index of each rail item in sorted_list, or the default value -1 if absent.
        lookup_indexes = lookup_table.lookup(rail)
        match_mask = ~tf.equal(lookup_indexes, -1)
        match_indexes = tf.where(condition=match_mask)
        extracted_match_indexes = tf.gather(
            lookup_indexes, match_indexes[:, 0], name="gather_extracted_match_indexes"
        )
        extracted_sorted_list = tf.gather(
            sorted_list, extracted_match_indexes, name="gather_extracted_sorted_list"
        )
        sorted_match_indexes = tf.argsort(extracted_match_indexes)
        reordered_extracted_sorted_list = tf.gather(
            extracted_sorted_list,
            sorted_match_indexes,
            name="gather_reordered_extracted_sorted_list",
        )
        composite_rail = tf.tensor_scatter_nd_update(
            tensor=rail,
            indices=match_indexes,
            updates=reordered_extracted_sorted_list,
            name="composite_rail_scatter",
        )
        return composite_rail

    # Maps each string in sorted_list to its position in sorted_list.
    lookup_table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int32,
        default_value=-1,
        empty_key="$",
        deleted_key="£",
    )
    lookup_table.insert(
        sorted_list,
        tf.range(0, tf.size(sorted_list), dtype=tf.int32),
        name="lookup_table_insert",
    )

    ragged_rails = tf.map_fn(
        reorder_rail,
        rails_to_sort,
        parallel_iterations=50,
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.string),
        name="rails_to_sort_map_fn",
    )

    return ragged_rails
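
For anyone following along, the per-row logic of reorder_rail can be mirrored in plain Python as a reference to test the TF version against. This is a minimal sketch, assuming the intended behaviour is: items of a rail that appear in sorted_list are rearranged into sorted_list order within the slots they already occupy, and all other items stay where they are. A dict stands in for the DenseHashTable, and the function names are hypothetical.

```python
from typing import Dict, List, Sequence


def reorder_rail_reference(rail: Sequence[str], sorted_list: Sequence[str]) -> List[str]:
    """Hypothetical pure-Python mirror of reorder_rail (not TF code)."""
    # Stand-in for the DenseHashTable: key -> position in sorted_list.
    # The TF version returns -1 for missing keys; here we just test membership.
    rank: Dict[str, int] = {key: i for i, key in enumerate(sorted_list)}

    # Positions in the rail whose item is in the table (tf.where on match_mask).
    positions = [i for i, item in enumerate(rail) if item in rank]

    # Matched items ordered by their rank (the gather / argsort / gather steps).
    reordered = sorted((rail[i] for i in positions), key=lambda item: rank[item])

    # Write them back into the matched slots (tensor_scatter_nd_update).
    out = list(rail)
    for pos, item in zip(positions, reordered):
        out[pos] = item
    return out


def sort_rails_reference(sorted_list: Sequence[str], rails: Sequence[Sequence[str]]) -> List[List[str]]:
    # tf.map_fn over the ragged rows becomes a list comprehension.
    return [reorder_rail_reference(rail, sorted_list) for rail in rails]


print(sort_rails_reference(["a", "b", "c"], [["c", "x", "a", "b"], ["z"], []]))
# [['a', 'x', 'b', 'c'], ['z'], []]
```

In the first row, "x" is not in sorted_list so it keeps slot 1, while "c", "a", "b" are rewritten into slots 0, 2 and 3 in sorted_list order.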

The While node looks like this:

node_def {
      name: "rails_to_sort_map_fn/while"
      op: "While"
      input: "rails_to_sort_map_fn/while/loop_counter:output:0"
      input: "rails_to_sort_map_fn/strided_slice:output:0"
      input: "rails_to_sort_map_fn/Const:output:0"
      input: "rails_to_sort_map_fn/TensorArrayV2_1:handle:0"
      input: "rails_to_sort_map_fn/strided_slice:output:0"
      input: "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0"
      input: "MutableDenseHashTable:table_handle:0"
      input: "default_value:output:0"
      input: "sorted_list"
      input: "^lookup_table_insert/LookupTableInsertV2"
      attr {
        key: "T"
        value {
          list {
            type: DT_INT32
            type: DT_INT32
            type: DT_INT32
            type: DT_VARIANT
            type: DT_INT32
            type: DT_VARIANT
            type: DT_RESOURCE
            type: DT_INT32
            type: DT_STRING
          }
        }
      }
      attr {
        key: "_lower_using_switch_merge"
        value {
          b: true
        }
      }
      attr {
        key: "_num_original_outputs"
        value {
          i: 9
        }
      }
      attr {
        key: "_read_only_resource_inputs"
        value {
          list {
          }
        }
      }
      attr {
        key: "body"
        value {
          func {
            name: "rails_to_sort_map_fn_while_body_18098"
          }
        }
      }
      attr {
        key: "cond"
        value {
          func {
            name: "rails_to_sort_map_fn_while_cond_18097"
          }
        }
      }
      attr {
        key: "output_shapes"
        value {
          list {
            shape {
            }
            shape {
            }
            shape {
            }
            shape {
            }
            shape {
            }
            shape {
            }
            shape {
            }
            shape {
            }
            shape {
              dim {
                size: 828
              }
            }
          }
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 50
        }
      }
    }
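
To confirm which operand "argument #6" in the error refers to, the data inputs of the While node can be paired with its T attribute. This is a small sketch using values copied from the node above, assuming the index in the error message is zero-based and that the control input (prefixed with ^) is not counted as an operand.

```python
# Inputs and the "T" attr of rails_to_sort_map_fn/while, copied from the NodeDef above.
inputs = [
    "rails_to_sort_map_fn/while/loop_counter:output:0",
    "rails_to_sort_map_fn/strided_slice:output:0",
    "rails_to_sort_map_fn/Const:output:0",
    "rails_to_sort_map_fn/TensorArrayV2_1:handle:0",
    "rails_to_sort_map_fn/strided_slice:output:0",
    "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0",
    "MutableDenseHashTable:table_handle:0",
    "default_value:output:0",
    "sorted_list",
    "^lookup_table_insert/LookupTableInsertV2",
]
types = [
    "DT_INT32", "DT_INT32", "DT_INT32", "DT_VARIANT", "DT_INT32",
    "DT_VARIANT", "DT_RESOURCE", "DT_INT32", "DT_STRING",
]

# Control inputs ("^...") only order execution; they are not loop operands.
data_inputs = [name for name in inputs if not name.startswith("^")]
assert len(data_inputs) == len(types)

for index, (name, dtype) in enumerate(zip(data_inputs, types)):
    print(f"#{index}: {dtype:<11} {name}")
```

With zero-based indexing, #6 is MutableDenseHashTable:table_handle:0 with type DT_RESOURCE, which matches the reading that argument #6 is the DenseHashTable.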

The graph executes as expected, so we wonder whether MLIR does not yet support references to DenseHashTable, or whether this is an MLIR bug.

Our primary concern is the effect of this failure. Does it stop all graph optimization on the target server, or only the optimization of that node?

Thanks

Adrian

Hey,

This is a TFG issue rather than an MLIR one (MLIR is used to build TFG, but TFG is maintained by a different group of people; I mention this mostly because it determines who triages and fixes it).

If memory serves (and it’s been a few years), when this pass fails it simply becomes a no-op, so it doesn’t stop or affect any other optimization passes.
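
The "failed pass becomes a no-op" behaviour described above can be pictured with a toy pipeline. This is a hypothetical model only, not Grappler’s meta-optimizer code: a "graph" is just a list of strings, failures are modelled as Python exceptions, and all function names are made up for illustration.

```python
import logging
from typing import Callable, List

Graph = List[str]  # hypothetical stand-in for a GraphDef
GraphPass = Callable[[Graph], Graph]


def run_passes(graph: Graph, passes: List[GraphPass]) -> Graph:
    """Toy model: a failing pass is logged and skipped; the remaining passes still run."""
    for graph_pass in passes:
        try:
            graph = graph_pass(graph)  # passes return a new graph, never mutate in place
        except Exception as err:
            logging.error("%s failed: %s", graph_pass.__name__, err)
            # The graph from before this pass is kept, so the pass acts as a no-op.
    return graph


def fake_folding_pass(graph: Graph) -> Graph:
    return graph + ["folded"]


def fake_tfg_pass(graph: Graph) -> Graph:
    raise ValueError("MLIR Graph Optimizer failed")


def fake_pruning_pass(graph: Graph) -> Graph:
    return graph + ["pruned"]


print(run_passes(["while"], [fake_folding_pass, fake_tfg_pass, fake_pruning_pass]))
# ['while', 'folded', 'pruned']
```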

The error message appears to say that you have an invalid input, which just happens to fail here first. This could also be because the type system is more precise here, and the runtime may have let this slide. In particular, the while loop is supposed to operate as a fixed point with respect to type, while that input argument changes between a single tensor and a “struct” tensor. Possibly just manually casting or setting the type would suffice.
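
The "fixed point with respect to type" requirement can be illustrated as a check over loop-carried values: each operand type must match the body’s argument type, and the body’s result type must match the argument type, so the loop state has the same type on every iteration. A toy sketch: the helper name is hypothetical, exact string equality is a simplification of real type compatibility, and the first six slots are placeholders (the remaining loop slots are omitted).

```python
from typing import List, Sequence


def loop_type_mismatches(
    operand_types: Sequence[str],
    body_arg_types: Sequence[str],
    body_result_types: Sequence[str],
) -> List[str]:
    """Describe every loop-carried value whose type is not invariant across iterations."""
    if not (len(operand_types) == len(body_arg_types) == len(body_result_types)):
        raise ValueError("operand, argument and result lists must have the same length")
    problems = []
    for i, (operand, arg, result) in enumerate(
        zip(operand_types, body_arg_types, body_result_types)
    ):
        # Loop entry: the body must accept exactly what the caller passes in.
        if arg != operand:
            problems.append(
                f"body argument #{i}: '{arg}' is not compatible with operand '{operand}'"
            )
        # Back edge: what the body returns is fed back in as the next argument.
        if result != arg:
            problems.append(f"body result #{i}: '{result}' differs from argument '{arg}'")
    return problems


other_slots = ["tensor<placeholder>"] * 6  # placeholders for slots not under discussion
operands = other_slots + ["tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>"]
body_args = other_slots + ["tensor<!tf_type.resource<tensor<!tf_type.string>>>"]

for problem in loop_type_mismatches(operands, body_args, body_args):
    print(problem)
```

Only slot #6 is reported, mirroring the TFG error: the body agrees with itself across iterations, but not with what is passed into the loop.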

– Jacques

@Jacques_Pienaar Hi Jacques

Thank you for the reply. I’m trying to understand how your suggestion of a cast could be achieved. We can’t cast the DenseHashTable itself to int32. I did try that, even though logically it doesn’t make sense, and got this error:

TypeError: Failed to convert elements of <tensorflow.python.ops.lookup_ops.DenseHashTable object at 0x18b7b4b80> to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

I can’t see a way of restructuring the code without dropping tf.map_fn, because the function needs to perform a lookup using the rail input.

Adrian

Yeah, there are a lot of layers here :slight_smile:

Let me start with what the error report is saying. It’s not about casting to int32; it’s about a tuple of tensors.

op body function argument
'tensor<!tf_type.resource<tensor<!tf_type.string>>>'
is not compatible with corresponding operand type:
'tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>'

So we have a resource with a single tensor vs. a resource with two tensors. Conceptually:

struct X { tensor<string> s; }

vs

struct Y { tensor<string> s; tensor<int32_t> i; }

You have a resource of Y where a resource of X is expected. So it would seem that, if one did a load from the resource before the loop/slice, or if the generated function body expected that resource type, it would line up; as it is, the two are close, but the operand is not a resource of X. So this is about the type of MutableDenseHashTable:table_handle:0 (which is due to how it’s created) versus the function body. From your example, the resource is a hash table, but the body just expects the keys (well, it expects a resource of the keys). I don’t know why, though.
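
The X vs. Y comparison can be made mechanical by pulling the subtypes out of each resource type in the error message. A small string-parsing sketch: the helper is hypothetical and only handles the textual form seen in this error, not MLIR type syntax in general.

```python
from typing import List


def resource_subtypes(type_str: str) -> List[str]:
    """Split the contents of !tf_type.resource<...> into its top-level subtypes."""
    marker = "!tf_type.resource<"
    start = type_str.index(marker) + len(marker)  # raises ValueError if absent
    depth = 0
    current: List[str] = []
    subtypes: List[str] = []
    for ch in type_str[start:]:
        if ch == "<":
            depth += 1
        elif ch == ">":
            if depth == 0:  # closing bracket of resource<...>
                break
            depth -= 1
        if ch == "," and depth == 0:  # top-level separator between subtypes
            subtypes.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    subtypes.append("".join(current).strip())
    return subtypes


operand_type = "tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>"  # like struct Y
body_arg_type = "tensor<!tf_type.resource<tensor<!tf_type.string>>>"  # like struct X

print(resource_subtypes(operand_type))   # ['tensor<!tf_type.string>', 'tensor<i32>']
print(resource_subtypes(body_arg_type))  # ['tensor<!tf_type.string>']
```

The operand carries both the string and the i32 subtype (matching the table’s string keys and int32 values), while the body argument carries only the string one.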

Now, a simple trick could be to have TFG not verify this and hope for the best :slight_smile: I’m not saying this as facetiously as it may sound. There was the whole FullType effort, which would have greatly improved this, but that’s no longer planned. So skipping the check should not result in any issue today, especially as this verification would be most useful for XLA, where the TF2XLA bridge needs to do the same verification later; but as XLA doesn’t support strings, this graph won’t make it there anyway.

If you look in tensorflow/core/protobuf/rewriter_config.proto, you’ll see a flag to disable TFG. I’d start there and verify that the output graphs are different.

The flagged error is one where I’d expect the runtime to complain too (it feels wrong), in which case disabling the TFG import should result in a failure later. If so, the triage moves higher up the stack, to the TF frontend/data folks.

(Oh, and my flair, or whatever one calls it here, is not accurate: I’m no longer on the TF team.)