Extending the tf.Tensor class

Hey there!

I want to extend the tf.Tensor class, but neither of the following options works. I specifically want to subclass it, though, rather than store a tf.Tensor instance as an attribute of my own class!

  1. Subclass tf.Tensor directly:
import tensorflow as tf

class MyTFTensor(tf.Tensor):

    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

y = MyTFTensor._from_native(value=tf.zeros((3, 224, 224)))

Fails with:

/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/515175960.py in _from_native(cls, value)
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):
----> 6         value.__class__ = cls
      7         return value

TypeError: __class__ assignment: 'MyTFTensor' object layout differs from 'tensorflow.python.framework.ops.EagerTensor'
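
The first error comes from CPython's rule for __class__ assignment: the old and new classes must have compatible instance layouts. As the traceback shows, tf.zeros returns an EagerTensor, which appears to be backed by a C extension type, while MyTFTensor derives from the Python-level tf.Tensor class, so the two layouts presumably differ. The pure-Python sketch below reproduces the same kind of error without TensorFlow. The classes Plain, PlainChild and Slotted and the helper try_reassign are hypothetical stand-ins, with Slotted playing the role of a type whose layout does not match.

```python
class Plain:
    """Ordinary Python class: instances get a __dict__ and __weakref__."""


class PlainChild(Plain):
    """Adds no new fields, so its instance layout matches Plain."""


class Slotted:
    """Hypothetical stand-in for a type with a different, fixed layout."""
    __slots__ = ("value",)


def try_reassign(obj, cls):
    """Attempt obj.__class__ = cls; return None on success, else the error text."""
    try:
        obj.__class__ = cls
    except TypeError as exc:
        return str(exc)
    return None


ok = try_reassign(Plain(), PlainChild)     # same layout: assignment succeeds
bad = try_reassign(Slotted(), PlainChild)  # incompatible layout: TypeError
print(ok)   # None
print(bad)  # a TypeError message starting with "__class__ assignment:"
```

In other words, re-typing an object in place only works between classes that share a layout, which is why a Python subclass of tf.Tensor cannot be swapped onto an existing EagerTensor.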

  2. Subclass EagerTensor:
from tensorflow.python.framework.ops import EagerTensor
class MyTFTensor(EagerTensor):
    
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

Fails with:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/3632871733.py in <cell line: 2>()
      1 from tensorflow.python.framework.ops import EagerTensor
----> 2 class MyTFTensor(EagerTensor):
      3 
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):

TypeError: type 'tensorflow.python.framework.ops.EagerTensor' is not an acceptable base type
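
The second error is what CPython raises when a class statement names a base type whose C-level flags do not include Py_TPFLAGS_BASETYPE, i.e. a type that was deliberately made non-subclassable. The built-in bool is a familiar example. The sketch below inspects that flag through the CPython-specific type.__flags__ attribute; is_subclassable and try_subclass are hypothetical helper names, and the bit value 1 << 10 is taken from CPython's Include/object.h.

```python
# CPython's Py_TPFLAGS_BASETYPE bit (see Include/object.h); a type without it
# cannot be used as a base class.
Py_TPFLAGS_BASETYPE = 1 << 10


def is_subclassable(tp: type) -> bool:
    """Return True if tp's C-level flags allow it to be subclassed (CPython only)."""
    return bool(tp.__flags__ & Py_TPFLAGS_BASETYPE)


def try_subclass(base: type):
    """Try to build a subclass of base; return None on success, else the error text."""
    try:
        type("Probe", (base,), {})
    except TypeError as exc:
        return str(exc)
    return None


print(is_subclassable(int))   # True
print(is_subclassable(bool))  # False
print(try_subclass(int))      # None
print(try_subclass(bool))     # type 'bool' is not an acceptable base type
```

Assuming TensorFlow is installed, is_subclassable(EagerTensor) would presumably return False as well, which is consistent with the traceback above: the flag is fixed when the C type is defined, so it cannot be worked around from Python.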

Does anyone have a solution for this?