Extending the tf.Tensor class

Hey there!

I want to extend the tf.Tensor class, but neither of the following options works. I specifically want to subclass it, rather than store a tf.Tensor instance as an attribute of my own class!

  1. Option: extend tf.Tensor:
import tensorflow as tf
class MyTFTensor(tf.Tensor):
    
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

y = MyTFTensor._from_native(value=tf.zeros((3, 224, 224)))

Fails with:

/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/515175960.py in _from_native(cls, value)
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):
----> 6         value.__class__ = cls
      7         return value

TypeError: __class__ assignment: 'MyTFTensor' object layout differs from 'tensorflow.python.framework.ops.EagerTensor'

  2. Option: extend EagerTensor:
from tensorflow.python.framework.ops import EagerTensor
class MyTFTensor(EagerTensor):
    
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

Fails with:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/3632871733.py in <cell line: 2>()
      1 from tensorflow.python.framework.ops import EagerTensor
----> 2 class MyTFTensor(EagerTensor):
      3 
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):

TypeError: type 'tensorflow.python.framework.ops.EagerTensor' is not an acceptable base type

Does anyone have a solution for this?

Hi @anon25443

Welcome to the TensorFlow Forum!

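You cannot turn an existing tensor into an instance of a tf.Tensor subclass by reassigning its __class__; that is what produces the first error above. Subclassing tf.Tensor is not recommended and is often not feasible because of TensorFlow's internal complexity.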
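
For what it's worth, the "object layout differs" error can be reproduced without TensorFlow. Python only allows `__class__` assignment when the old and new classes have compatible instance layouts. Below is a minimal sketch in plain Python; `Slotted`, `WithDict`, `PlainA`, `PlainB` and `try_reclass` are hypothetical names made up for illustration, and the example does not model TensorFlow's actual internals.

```python
class Slotted:
    # No per-instance __dict__: instances only have room for "x".
    __slots__ = ("x",)


class WithDict(Slotted):
    # No __slots__ here, so instances gain a __dict__ and a different layout.
    pass


class PlainA:
    pass


class PlainB(PlainA):
    # Same layout as PlainA: both are ordinary classes with a __dict__.
    pass


def try_reclass(obj, new_cls):
    """Try to reassign obj.__class__; return "ok" or the TypeError message."""
    try:
        obj.__class__ = new_cls
        return "ok"
    except TypeError as exc:
        return str(exc)


layout_result = try_reclass(Slotted(), WithDict)  # incompatible layouts
plain_result = try_reclass(PlainA(), PlainB)      # compatible layouts

print(layout_result)
print(plain_result)
```

The first call fails with the same kind of message as in the question, while the second succeeds. An EagerTensor and a Python-level subclass of tf.Tensor presumably fall into the first category.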

Inheriting directly from EagerTensor is not allowed because it’s not designed to be a base class for user-defined extensions, which is why the second option already fails at the class definition. Thank you.
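
The "is not an acceptable base type" message is what CPython raises whenever a type has not been marked as subclassable, and some built-in types behave the same way. Below is a minimal sketch, assuming that EagerTensor falls into that category; `try_subclass` is a hypothetical helper written only for this illustration.

```python
def try_subclass(base):
    """Try to create a subclass of base; return "ok" or the TypeError message."""
    try:
        # Equivalent to writing: class Sub(base): pass
        type("Sub", (base,), {})
        return "ok"
    except TypeError as exc:
        return str(exc)


bool_result = try_subclass(bool)  # bool cannot be used as a base class
int_result = try_subclass(int)    # int can be subclassed normally

print(bool_result)
print(int_result)
```

Because the restriction is set on the type itself, nothing written in the subclass body can work around it.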