Implementing unet with tensorflow subclassing API does not work

I am trying to implement Unet with TensorFlow subclassing API and something does not seem to work properly, and I get the following error:

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

Furthermore, I am uncertain if I have correctly implemented the logic inside the call() function. Any help to correct my mistakes would be much appreciated.

Here I am attaching the full copy of the implementation and the error tracks:

Code Implementation:


from functools import partial

keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)

conv2d = partial(keras.layers.Conv2D, kernel_size = 3, 
                    padding = 'SAME',
                    kernel_initializer = 'he_normal', 
                    use_bias = False)

conv2dtranspose = partial(keras.layers.Conv2DTranspose, 
                          kernel_size = 2, strides = 2, 
                          padding = 'SAME')

class encoder(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super(encoder, self).__init__(**kwargs)
        
        self.convs = [
            conv2d(filters),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            conv2d(filters),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu')
                         ]
        
    def call(self, inputs):
        Z = inputs
        for layer in self.convs:
            Z = layer(Z)
        return Z


class UNet(keras.models.Model):
    def __init__(self, filters, inputs_shape = [128, 128, 1], **kwargs):
        super(UNet, self).__init__(**kwargs)
        self.filters = filters
        self.inputs = keras.layers.Input(shape = inputs_shape)
        self.maxpool2d = keras.layers.MaxPool2D(pool_size = (2, 2), strides = 2)
        self.conv2dtranspose = conv2dtranspose
        self.concat = keras.layers.Concatenate()
        
        
    def call(self, inputs):
        skips = {}

        Z, inpt = inputs

        #implementing encoder path
        for fId in range(len(self.filters)):
            Z = encoder(filters = self.filters[fId])(Z)
            if fId < len(self.filters) - 1:
                skips[fId] = Z
                Z = self.maxpool2d(Z)

        #implementing decoder path
        for fId in reversed(range(len(self.filters) - 1)):
            Z = self.conv2dtranspose(self.filters[fId])(Z)
            Z = self.concat([Z, skips[fId]])
            Z = encoder(self.filters[::-1][fId])(Z)
        
        output =  keras.layers.Conv2D(1, kernel_size = 1, activation = 'sigmoid')(Z)    
        return keras.Model(inputs = [inpt], outputs = [output])

    
filters = [64, 128, 256, 512]

inpt = keras.layers.Input(shape = [128, 128, 1])
model = UNet(filters = filters)(inpt)

#Generating some test data

x = tf.random.normal(shape = (10, 128, 128, 1))
y = tf.random.normal(shape = (10, 128, 128, 1))


model.compile(loss = 'binary_crossentropy', optimizer = keras.optimizers.SGD(), metrics = ['accuracy'])
model.fit(x, y, epochs = 3)

Error Tracks:


WARNING:tensorflow:AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    446     program_ctx = converter.ProgramContext(options=options)
--> 447     converted_f = _convert_actual(target_entity, program_ctx)
    448     if logging.has_verbosity(2):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _convert_actual(entity, program_ctx)
    283 
--> 284   transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)
    285 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform(self, obj, user_context)
    285     if inspect.isfunction(obj) or inspect.ismethod(obj):
--> 286       return self.transform_function(obj, user_context)
    287 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
    469           # TODO(mdan): Confusing overloading pattern. Fix.
--> 470           nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
    471 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
    362     node = self._erase_arg_defaults(node)
--> 363     result = self.transform_ast(node, context)
    364 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in transform_ast(self, node, ctx)
    251     unsupported_features_checker.verify(node)
--> 252     node = self.initial_analysis(node, ctx)
    253 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in initial_analysis(self, node, ctx)
    238     graphs = cfg.build(node)
--> 239     node = qual_names.resolve(node)
    240     node = activity.resolve(node, ctx, None)

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in resolve(node)
    251 def resolve(node):
--> 252   return QnResolver().visit(node)
    253 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    446                     if isinstance(value, AST):
--> 447                         value = self.visit(value)
    448                         if value is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    446                     if isinstance(value, AST):
--> 447                         value = self.visit(value)
    448                         if value is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    455             elif isinstance(old_value, AST):
--> 456                 new_node = self.visit(old_value)
    457                 if new_node is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    455             elif isinstance(old_value, AST):
--> 456                 new_node = self.visit(old_value)
    457                 if new_node is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    446                     if isinstance(value, AST):
--> 447                         value = self.visit(value)
    448                         if value is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    455             elif isinstance(old_value, AST):
--> 456                 new_node = self.visit(old_value)
    457                 if new_node is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in visit_Subscript(self, node)
    231     s = node.slice
--> 232     if not isinstance(s, gast.Index):
    233       # TODO(mdan): Support range and multi-dimensional indices.

AttributeError: module 'gast' has no attribute 'Index'

During handling of the above exception, another exception occurred:

OperatorNotAllowedInGraphError            Traceback (most recent call last)
<ipython-input-449-e6f92329b0db> in <module>
      2 
      3 inpt = keras.layers.Input(shape = [128, 128, 1])
----> 4 model = UNet(filters = filters)(inpt)
      5 
      6 #Generating some test data

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    944     # >> model = tf.keras.Model(inputs, outputs)
    945     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 946       return self._functional_construction_call(inputs, args, kwargs,
    947                                                 input_list)
    948 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1083           layer=self, inputs=inputs, build_graph=True, training=training_value):
   1084         # Check input assumptions set after layer building, e.g. input shape.
-> 1085         outputs = self._keras_tensor_symbolic_call(
   1086             inputs, input_masks, args, kwargs)
   1087 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
    815       return nest.map_structure(keras_tensor.KerasTensor, output_signature)
    816     else:
--> 817       return self._infer_output_signature(inputs, args, kwargs, input_masks)
    818 
    819   def _infer_output_signature(self, inputs, args, kwargs, input_masks):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
    856           # TODO(kaftan): do we maybe_build here, or have we already done it?
    857           self._maybe_build(inputs)
--> 858           outputs = call_fn(inputs, *args, **kwargs)
    859 
    860         self._handle_activity_regularization(inputs, outputs)

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    665       try:
    666         with conversion_ctx:
--> 667           return converted_call(f, args, kwargs, options=options)
    668       except Exception as e:  # pylint:disable=broad-except
    669         if hasattr(e, 'ag_error_metadata'):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    452     if is_autograph_strict_conversion_mode():
    453       raise
--> 454     return _fall_back_unconverted(f, args, kwargs, options, e)
    455 
    456   with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _fall_back_unconverted(f, args, kwargs, options, exc)
    499     logging.warn(warning_template, f, file_bug_message, exc)
    500 
--> 501   return _call_unconverted(f, args, kwargs, options)
    502 
    503 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
    476 
    477   if kwargs is not None:
--> 478     return f(*args, **kwargs)
    479   return f(*args)
    480 

<ipython-input-448-ce9f55fd84b1> in call(self, inputs)
     49         skips = {}
     50 
---> 51         Z, inpt = inputs
     52 
     53         #implementing encoder path

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
    503   def __iter__(self):
    504     if not context.executing_eagerly():
--> 505       self._disallow_iteration()
    506 
    507     shape = self._shape_tuple()

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
    499     else:
    500       # Default: V1-style Graph execution.
--> 501       self._disallow_in_graph_mode("iterating over `tf.Tensor`")
    502 
    503   def __iter__(self):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_in_graph_mode(self, task)
    477 
    478   def _disallow_in_graph_mode(self, task):
--> 479     raise errors.OperatorNotAllowedInGraphError(
    480         "{} is not allowed in Graph execution. Use Eager execution or decorate"
    481         " this function with @tf.function.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.