TF 2.7.0 dispatch_for_api and tf.matmul

Hello! I’m trying to replace every use of matrix multiplication with my own implementation. For the purposes of this simple example and for verification, I’m just implementing something that produces zeros whenever tf.matmul is used.

I’m currently using TF 2.7.0 on macOS with Python 3.9.0 for development. I based the code below on the TensorFlow Core guide “Extension types”.

import sys
from typing import Tuple, List, Mapping, Union, Optional

import numpy as np
import tensorflow as tf

class MyCustomTensor(tf.experimental.BatchableExtensionType):
    """Minimal extension type wrapping a single dense tensor.

    It adds no numerical behavior of its own; its purpose is to be a distinct
    type that TensorFlow APIs can route to handlers registered with the
    ``tf.experimental.dispatch_for_*`` decorators.
    """

    # Explicit, stable type name for this extension type.
    __name__ = 'replace.tf.MyCustomTensor'

    # The wrapped dense tensor.
    values: tf.Tensor

    # Forward the wrapped tensor's static shape and dtype so callers can treat
    # this object like a tensor when they only inspect .shape / .dtype.
    shape = property(lambda self: self.values.shape)
    dtype = property(lambda self: self.values.dtype)


    class Spec:
        """Type spec for MyCustomTensor, described by a shape and a dtype."""

        def __init__(self, shape, dtype=tf.float32):
            # Spec of the single `values` field.
            self.values = tf.TensorSpec(shape, dtype)

        shape = property(lambda self: self.values.shape)
        dtype = property(lambda self: self.values.dtype)

        def with_shape(self, shape=None):
            """Return a copy of this spec with ``shape`` and the same dtype.

            Args:
                shape: The new shape (anything ``tf.TensorSpec`` accepts as a
                    shape). If None, the current shape is kept.

            Returns:
                A new ``MyCustomTensor.Spec``.
            """
            if shape is None:
                shape = self.shape
            # Spec.__init__ expects a shape and a dtype, not a TensorSpec.
            return MyCustomTensor.Spec(shape, self.dtype)

def convert_to_custom_tensor(x):
    """Return ``x`` as a MyCustomTensor.

    A value that is already a MyCustomTensor is returned unchanged (same
    object); anything else is passed to the MyCustomTensor constructor.
    """
    already_wrapped = isinstance(x, MyCustomTensor)
    return x if already_wrapped else MyCustomTensor(x)


@tf.experimental.dispatch_for_unary_elementwise_apis(MyCustomTensor)
def unary_elementwise_op_handler(op, x):
    """Apply a unary elementwise TensorFlow API to a MyCustomTensor.

    Args:
        op: The elementwise API function being dispatched.
        x: The MyCustomTensor operand.

    Returns:
        A new MyCustomTensor wrapping ``op`` applied to ``x.values``.
    """
    transformed = op(x.values)
    return MyCustomTensor(transformed)


@tf.experimental.dispatch_for_binary_elementwise_apis(
    Union[MyCustomTensor, tf.Tensor], Union[MyCustomTensor, tf.Tensor])
def binary_elementwise_op_handler(op, x, y):
    """Apply a binary elementwise TensorFlow API to mixed operands.

    Either operand may be a MyCustomTensor or a plain tf.Tensor. Both are
    normalized to MyCustomTensor (x first, then y) before ``op`` runs on the
    wrapped values.

    Returns:
        A new MyCustomTensor wrapping ``op(x.values, y.values)``.
    """
    lhs, rhs = (convert_to_custom_tensor(operand) for operand in (x, y))
    return MyCustomTensor(op(lhs.values, rhs.values))


@tf.experimental.dispatch_for_api(tf.matmul)
def custom_matmul(a: MyCustomTensor, b,
                  transpose_a=False, transpose_b=False,
                  adjoint_a=False, adjoint_b=False,
                  a_is_sparse=False, b_is_sparse=False,
                  output_type=None, name=None):
    """Replacement for ``tf.matmul`` when ``a`` is a MyCustomTensor.

    For this experiment, every MyCustomTensor operand is replaced by zeros of
    the same runtime shape, so the product is all zeros. The parameter list is
    kept identical to ``tf.matmul``'s so it can serve as its dispatch target.

    Returns:
        A plain ``tf.Tensor`` (not a MyCustomTensor).
    """
    if isinstance(a, MyCustomTensor):
        # tf.shape() also works when a dimension (e.g. batch) is unknown at
        # trace time, where tf.zeros(a.shape) would fail. The dtype is
        # tf.zeros' default (float32).
        a = tf.zeros(tf.shape(a.values))
    if isinstance(b, MyCustomTensor):
        b = tf.zeros(tf.shape(b.values))

    tf.print("Matmul replaced!", output_stream=sys.stdout)
    # `a` is now a plain tensor, so this call does not dispatch back here.
    # Keyword arguments avoid depending on tf.matmul's positional order, and
    # `name` is forwarded instead of being silently dropped.
    return tf.matmul(a, b,
                     transpose_a=transpose_a, transpose_b=transpose_b,
                     adjoint_a=adjoint_a, adjoint_b=adjoint_b,
                     a_is_sparse=a_is_sparse, b_is_sparse=b_is_sparse,
                     output_type=output_type, name=name)

I then create a simple model:

# Input signature: a MyCustomTensor of shape [1, 2] and dtype float32.
dense_input_spec = MyCustomTensor.Spec(shape=[1, 2], dtype=tf.float32)

# Two bias-free Dense layers on top of the custom-typed input.
dense_model = tf.keras.Sequential(layers=[
    tf.keras.layers.Input(type_spec=dense_input_spec),
    tf.keras.layers.Dense(16, activation="relu", use_bias=False),
    tf.keras.layers.Dense(1, use_bias=False),
])

# NOTE(review): np.ones defaults to float64 while the spec above declares
# float32 — confirm this mismatch is intended.
dense_model(MyCustomTensor(values=np.ones((1, 2))))

dense_model(np.ones((1,2))) gives me a nonzero value, whereas dense_model(MyCustomTensor(np.ones((1,2)))) gives me 0, as expected. Awesome!

However, when I try:

# Input signature: a MyCustomTensor of shape [1, 224, 224, 3], float32.
conv_input_spec = MyCustomTensor.Spec(shape=[1, 224, 224, 3], dtype=tf.float32)
conv_model = tf.keras.Sequential(layers=[
    tf.keras.layers.Input(type_spec=conv_input_spec),
    tf.keras.layers.Conv2D(filters=3, kernel_size=3, use_bias=False),
])

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/d7/c4dm46h53_d9vhx9_q_y12fm0000gn/T/ipykernel_9694/354279810.py in <module>
      1 conv_input_spec = MyCustomTensor.Spec([1,224, 224, 3], tf.float32)
----> 2 conv_model = tf.keras.Sequential([
      3     tf.keras.layers.Input(type_spec=conv_input_spec),
      4     tf.keras.layers.Conv2D(3, 3, use_bias=False),
      5 ]

~/.pyenv/versions/3.9.0/envs/idiom-ml-tf27/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    528     self._self_setattr_tracking = False  # pylint: disable=protected-access
    529     try:
--> 530       result = method(self, *args, **kwargs)
    531     finally:
    532       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

~/.pyenv/versions/3.9.0/envs/idiom-ml-tf27/lib/python3.9/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~/.pyenv/versions/3.9.0/envs/idiom-ml-tf27/lib/python3.9/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape, allow_broadcast)
    547       str_values = [compat.as_bytes(x) for x in proto_values]
    548     except TypeError:
--> 549       raise TypeError(f"Failed to convert elements of {values} to Tensor. "
    550                       "Consider casting elements to a supported type. See "
    551                       "https://www.tensorflow.org/api_docs/python/tf/dtypes "

TypeError: Exception encountered when calling layer "conv2d" (type Conv2D).

Failed to convert elements of MyCustomTensor(values=<tf.Tensor 'Placeholder:0' shape=(1, 224, 224, 3) dtype=float32>) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

Call arguments received:
  • inputs=MyCustomTensor(values=<tf.Tensor 'Placeholder:0' shape=(1, 224, 224, 3) dtype=float32>)

It turns out this also requires dispatching tf.nn.convolution, with something like:

@tf.experimental.dispatch_for_api(tf.nn.convolution)
def custom_convolution(input: MyCustomTensor,
                       filters,
                       strides=None,
                       padding="VALID",
                       data_format=None,
                       dilations=None,
                       name=None):
    """Replacement for ``tf.nn.convolution`` when ``input`` is a MyCustomTensor.

    Without the ``dispatch_for_api`` decorator this function is never
    registered, and TensorFlow would still try to convert the MyCustomTensor
    to a regular tensor. ``input`` is replaced by zeros of the same runtime
    shape, so the result is all zeros.

    Parameter names (including ``input``, which shadows the builtin) are kept
    identical to ``tf.nn.convolution``'s so it can serve as its dispatch
    target.

    Returns:
        A plain ``tf.Tensor`` (not a MyCustomTensor).
    """
    tf.print("Conv replaced!")
    # tf.shape() also handles dimensions that are unknown at trace time; the
    # dtype is tf.zeros' default (float32).
    input = tf.zeros(tf.shape(input.values))
    # `input` is now a plain tensor, so this call does not dispatch back here.
    return tf.nn.convolution(input,
                             filters,
                             strides=strides,
                             padding=padding,
                             data_format=data_format,
                             dilations=dilations,
                             name=name)

(Really silly and simple, I know, but this was just to test replacing ops.) With that in place I get the expected zeros. Looking through the source code, it seems that convolution operations are dispatched directly to C++ code rather than going through tf.matmul — is that right? Do I really need to add a dispatcher for every convolution op? Is there any way to replace matrix multiplication everywhere, for all ops at once?

Thank you in advance for any pointers!