Not able to lower tf.sets.intersection to HLO

My Python code snippet is as follows:
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
from tensorflow.python.framework import config
import tensorflow.compat.v1 as tf
# Turn on the MLIR-based bridge before any function is traced.
# `config` is tensorflow.python.framework.config; `tf` is bound to
# tensorflow.compat.v1 at this point because that import comes last.
# NOTE(review): both calls presumably toggle the same setting, so one of them
# is likely redundant - confirm. Neither is expected to affect the offline
# tf-mlir-translate / tf-opt pipeline run on the saved model - TODO confirm.
config.enable_mlir_bridge()
tf.config.experimental.enable_mlir_bridge()

class CustomModule(tf.Module):
    """Minimal module reproducing the ``tf.sets.intersection`` lowering failure.

    Holds a boolean mask and two int32 matrices as variables; calling the
    module selects between the matrices with ``tf.where`` and intersects the
    result with an identically computed copy of itself.
    """

    def __init__(self):
        """Create the 3x3 mask and the two 3x3 int32 operand variables."""
        super(CustomModule, self).__init__()
        self.condition = tf.Variable(
            np.array([[True, False, False],
                      [False, True, False],
                      [True, True, True]]),
            dtype=tf.bool)
        self.x = tf.Variable(
            np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), dtype=tf.int32)
        self.y = tf.Variable(
            np.array([[11, 12, 13], [14, 15, 16], [17, 18, 19]]),
            dtype=tf.int32)

    @tf.function
    def __call__(self, x):
        """Return the set intersection of two identical ``tf.where`` selections.

        Args:
            x: Placeholder input; it is not used by the computation and only
                gives the traced function a signature.

        Returns:
            The result of ``tf.sets.intersection`` on the two selections,
            each expanded with a leading dimension of size 1.
        """
        r = tf.where(self.condition, self.x, self.y)
        m = tf.where(self.condition, self.x, self.y)

        c = tf.sets.intersection(tf.expand_dims(r, 0), tf.expand_dims(m, 0))

        return c


module = CustomModule()

module_with_signature_path = os.path.join("/data/aruna/tf_ops", 'sets_intersection')
call = module.__call__.get_concrete_function(tf.TensorSpec(shape=(), dtype=tf.int32))
# Export under the 'predict' key, the name later passed to
# --tf-savedmodel-exported-names; the dict was previously built but unused.
signatures = {'predict': call}
tf.saved_model.save(module, module_with_signature_path, signatures=signatures)
print('Saving model…')

if __name__ == '__main__':
    test.main()

I ran this Python code and obtained saved_model.pb.
Then I ran the following commands:
tensorflow/compiler/mlir/tf-mlir-translate --savedmodel-objectgraph-to-mlir --tf-savedmodel-exported-names=predict -tf-enable-shape-inference-on-import=true $PWD -o sample.mlir
tensorflow/compiler/mlir/tf-opt --tf-executor-to-functional-conversion --tf-shape-inference -xla-legalize-tf --print-ir-before-all sample.mlir

TF dialect looks like:
// -----// IR Dump Before LegalizeTF //----- //
builtin.func private @__inference___call___750(%arg0: tensor {tf._user_specified_name = “x”}, %arg1: tensor<!tf_type.resource>, %arg2: tensor<!tf_type.resource>, %arg3: tensor<!tf_type.resource>) → (tensor<?x3xi64>, tensor<?xi32>, tensor<3xi64>) attributes {tf._construction_context = “kEagerRuntime”, tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {
%cst = “tf.Const”() {device = “”, value = dense<0> : tensor} : () → tensor
%cst_0 = “tf.Const”() {device = “”, value = dense<0> : tensor} : () → tensor
2021-09-08 09:56:50.579733: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
%0 = “tf.ReadVariableOp”(%arg2) {device = “”} : (tensor<!tf_type.resource>) → tensor<3x3xi32>
%1 = “tf.ReadVariableOp”(%arg2) {device = “”} : (tensor<!tf_type.resource>) → tensor<3x3xi32>
%2 = “tf.ReadVariableOp”(%arg3) {device = “”} : (tensor<!tf_type.resource>) → tensor<3x3xi32>
%3 = “tf.ReadVariableOp”(%arg3) {device = “”} : (tensor<!tf_type.resource>) → tensor<3x3xi32>
%4 = “tf.ReadVariableOp”(%arg1) {device = “”} : (tensor<!tf_type.resource>) → tensor<3x3xi1>
%5 = “tf.Select”(%4, %0, %2) {device = “”} : (tensor<3x3xi1>, tensor<3x3xi32>, tensor<3x3xi32>) → tensor<3x3xi32>
%6 = “tf.ExpandDims”(%5, %cst) {device = “”} : (tensor<3x3xi32>, tensor) → tensor<1x3x3xi32>
%7 = “tf.ReadVariableOp”(%arg1) {device = “”} : (tensor<!tf_type.resource>) → tensor<3x3xi1>
“tf.NoOp”() {_acd_function_control_output = true, device = “”} : () → ()
%8 = “tf.Select”(%7, %1, %3) {device = “”} : (tensor<3x3xi1>, tensor<3x3xi32>, tensor<3x3xi32>) → tensor<3x3xi32>
%9 = “tf.ExpandDims”(%8, %cst_0) {device = “”} : (tensor<3x3xi32>, tensor) → tensor<1x3x3xi32>
%10:3 = “tf.DenseToDenseSetOperation”(%6, %9) {T = i32, device = “”, set_operation = “intersection”, validate_indices = true} : (tensor<1x3x3xi32>, tensor<1x3x3xi32>) → (tensor<?x3xi64>, tensor<?xi32>, tensor<3xi64>)
%11 = “tf.Identity”(%10#0) {device = “”} : (tensor<?x3xi64>) → tensor<?x3xi64>
%12 = “tf.Identity”(%10#1) {device = “”} : (tensor<?xi32>) → tensor<?xi32>
%13 = “tf.Identity”(%10#2) {device = “”} : (tensor<3xi64>) → tensor<3xi64>
return %11, %12, %13 : tensor<?x3xi64>, tensor<?xi32>, tensor<3xi64>
}

Error is:
sample.mlir:5:3: error: The following operations cannot be legalized: tf.DenseToDenseSetOperation (count: 1); tf.NoOp (count: 1); tf.ReadVariableOp (count: 6). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.
builtin.func private @__inference___call___340(%arg0: tensor {tf._user_specified_name = “x”}, %arg1: tensor<!tf_type.resource>, %arg2: tensor<!tf_type.resource>, %arg3: tensor<!tf_type.resource>) → (tensor<?x3xi64>, tensor<?xi32>, tensor<3xi64>) attributes {tf._construction_context = “kEagerRuntime”, tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {
^
sample.mlir:5:3: error: Emitting more detail about one op that failed to legalize…
builtin.func private @__inference___call___340(%arg0: tensor {tf._user_specified_name = “x”}, %arg1: tensor<!tf_type.resource>, %arg2: tensor<!tf_type.resource>, %arg3: tensor<!tf_type.resource>) → (tensor<?x3xi64>, tensor<?xi32>, tensor<3xi64>) attributes {tf._construction_context = “kEagerRuntime”, tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {
^
sample.mlir:20:61: error: ‘tf.DenseToDenseSetOperation’ op is not legalizable
%outputs_23:3, %control_24 = tf_executor.island wraps “tf.DenseToDenseSetOperation”(%outputs_14, %outputs_21) {T = i32, device = “”, set_operation = “intersection”, validate_indices = true} : (tensor<1x3x3xi32>, tensor<1x3x3xi32>) → (tensor<?x3xi64>, tensor<?xi32>, tensor<3xi64>)

tf.CropAndResize
tf.StridedSlice
tf.Unique
tf.Where
tf.SparseToDense
tf.NonMaxSuppressionV4
tf.TensorListFromTensor
tf.TensorListGetItem
tf.DenseToDenseSetOperation
tf.TensorListReserve
tf.TensorListSetItem
tf.TensorListStack
tf.TopKV2

I get the same legalization error when lowering each of the ops listed above to HLO.