Implementing unet with tensorflow subclassing API does not work

I am trying to implement Unet with TensorFlow subclassing API and something does not seem to work properly, and I get the following error:

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

Furthermore, I am uncertain if I have correctly implemented the logic inside the call() function. Any help to correct my mistakes would be much appreciated.

Here I am attaching the full copy of the implementation and the error tracks:

Code Implementation:


from functools import partial

keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)

conv2d = partial(keras.layers.Conv2D, kernel_size = 3, 
                    padding = 'SAME',
                    kernel_initializer = 'he_normal', 
                    use_bias = False)

conv2dtranspose = partial(keras.layers.Conv2DTranspose, 
                          kernel_size = 2, strides = 2, 
                          padding = 'SAME')

class encoder(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super(encoder, self).__init__(**kwargs)
        
        self.convs = [
            conv2d(filters),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            conv2d(filters),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu')
                         ]
        
    def call(self, inputs):
        Z = inputs
        for layer in self.convs:
            Z = layer(Z)
        return Z


class UNet(keras.models.Model):
    def __init__(self, filters, inputs_shape = [128, 128, 1], **kwargs):
        super(UNet, self).__init__(**kwargs)
        self.filters = filters
        self.inputs = keras.layers.Input(shape = inputs_shape)
        self.maxpool2d = keras.layers.MaxPool2D(pool_size = (2, 2), strides = 2)
        self.conv2dtranspose = conv2dtranspose
        self.concat = keras.layers.Concatenate()
        
        
    def call(self, inputs):
        skips = {}

        Z, inpt = inputs

        #implementing encoder path
        for fId in range(len(self.filters)):
            Z = encoder(filters = self.filters[fId])(Z)
            if fId < len(self.filters) - 1:
                skips[fId] = Z
                Z = self.maxpool2d(Z)

        #implementing decoder path
        for fId in reversed(range(len(self.filters) - 1)):
            Z = self.conv2dtranspose(self.filters[fId])(Z)
            Z = self.concat([Z, skips[fId]])
            Z = encoder(self.filters[::-1][fId])(Z)
        
        output =  keras.layers.Conv2D(1, kernel_size = 1, activation = 'sigmoid')(Z)    
        return keras.Model(inputs = [inpt], outputs = [output])

    
filters = [64, 128, 256, 512]

inpt = keras.layers.Input(shape = [128, 128, 1])
model = UNet(filters = filters)(inpt)

#Generating some test data

x = tf.random.normal(shape = (10, 128, 128, 1))
y = tf.random.normal(shape = (10, 128, 128, 1))


model.compile(loss = 'binary_crossentropy', optimizer = keras.optimizers.SGD(), metrics = ['accuracy'])
model.fit(x, y, epochs = 3)

Error Tracks:


WARNING:tensorflow:AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    446     program_ctx = converter.ProgramContext(options=options)
--> 447     converted_f = _convert_actual(target_entity, program_ctx)
    448     if logging.has_verbosity(2):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _convert_actual(entity, program_ctx)
    283 
--> 284   transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)
    285 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform(self, obj, user_context)
    285     if inspect.isfunction(obj) or inspect.ismethod(obj):
--> 286       return self.transform_function(obj, user_context)
    287 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
    469           # TODO(mdan): Confusing overloading pattern. Fix.
--> 470           nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
    471 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
    362     node = self._erase_arg_defaults(node)
--> 363     result = self.transform_ast(node, context)
    364 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in transform_ast(self, node, ctx)
    251     unsupported_features_checker.verify(node)
--> 252     node = self.initial_analysis(node, ctx)
    253 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in initial_analysis(self, node, ctx)
    238     graphs = cfg.build(node)
--> 239     node = qual_names.resolve(node)
    240     node = activity.resolve(node, ctx, None)

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in resolve(node)
    251 def resolve(node):
--> 252   return QnResolver().visit(node)
    253 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    446                     if isinstance(value, AST):
--> 447                         value = self.visit(value)
    448                         if value is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    446                     if isinstance(value, AST):
--> 447                         value = self.visit(value)
    448                         if value is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    455             elif isinstance(old_value, AST):
--> 456                 new_node = self.visit(old_value)
    457                 if new_node is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    455             elif isinstance(old_value, AST):
--> 456                 new_node = self.visit(old_value)
    457                 if new_node is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    446                     if isinstance(value, AST):
--> 447                         value = self.visit(value)
    448                         if value is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
    455             elif isinstance(old_value, AST):
--> 456                 new_node = self.visit(old_value)
    457                 if new_node is None:

~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
    370         visitor = getattr(self, method, self.generic_visit)
--> 371         return visitor(node)
    372 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in visit_Subscript(self, node)
    231     s = node.slice
--> 232     if not isinstance(s, gast.Index):
    233       # TODO(mdan): Support range and multi-dimensional indices.

AttributeError: module 'gast' has no attribute 'Index'

During handling of the above exception, another exception occurred:

OperatorNotAllowedInGraphError            Traceback (most recent call last)
<ipython-input-449-e6f92329b0db> in <module>
      2 
      3 inpt = keras.layers.Input(shape = [128, 128, 1])
----> 4 model = UNet(filters = filters)(inpt)
      5 
      6 #Generating some test data

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    944     # >> model = tf.keras.Model(inputs, outputs)
    945     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 946       return self._functional_construction_call(inputs, args, kwargs,
    947                                                 input_list)
    948 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1083           layer=self, inputs=inputs, build_graph=True, training=training_value):
   1084         # Check input assumptions set after layer building, e.g. input shape.
-> 1085         outputs = self._keras_tensor_symbolic_call(
   1086             inputs, input_masks, args, kwargs)
   1087 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
    815       return nest.map_structure(keras_tensor.KerasTensor, output_signature)
    816     else:
--> 817       return self._infer_output_signature(inputs, args, kwargs, input_masks)
    818 
    819   def _infer_output_signature(self, inputs, args, kwargs, input_masks):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
    856           # TODO(kaftan): do we maybe_build here, or have we already done it?
    857           self._maybe_build(inputs)
--> 858           outputs = call_fn(inputs, *args, **kwargs)
    859 
    860         self._handle_activity_regularization(inputs, outputs)

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    665       try:
    666         with conversion_ctx:
--> 667           return converted_call(f, args, kwargs, options=options)
    668       except Exception as e:  # pylint:disable=broad-except
    669         if hasattr(e, 'ag_error_metadata'):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
    452     if is_autograph_strict_conversion_mode():
    453       raise
--> 454     return _fall_back_unconverted(f, args, kwargs, options, e)
    455 
    456   with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _fall_back_unconverted(f, args, kwargs, options, exc)
    499     logging.warn(warning_template, f, file_bug_message, exc)
    500 
--> 501   return _call_unconverted(f, args, kwargs, options)
    502 
    503 

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
    476 
    477   if kwargs is not None:
--> 478     return f(*args, **kwargs)
    479   return f(*args)
    480 

<ipython-input-448-ce9f55fd84b1> in call(self, inputs)
     49         skips = {}
     50 
---> 51         Z, inpt = inputs
     52 
     53         #implementing encoder path

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
    503   def __iter__(self):
    504     if not context.executing_eagerly():
--> 505       self._disallow_iteration()
    506 
    507     shape = self._shape_tuple()

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
    499     else:
    500       # Default: V1-style Graph execution.
--> 501       self._disallow_in_graph_mode("iterating over `tf.Tensor`")
    502 
    503   def __iter__(self):

~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_in_graph_mode(self, task)
    477 
    478   def _disallow_in_graph_mode(self, task):
--> 479     raise errors.OperatorNotAllowedInGraphError(
    480         "{} is not allowed in Graph execution. Use Eager execution or decorate"
    481         " this function with @tf.function.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

Hi @Ethen_Kaufmann

Welcome to the TensorFlow Forum!

The above mentioned error occurred because you are trying to implement the older TensorFlow code which has used older deprecated APIs of TensorFlow 1.x in the TensorFlow 2.x version.

Please check and update the code to be compatible with TensorFlow 2.x after replacing the existing APIs with the new TensorFlow 2.x APIs for the same functionality using mentioned link as a reference.

Please try again by doing above changes and let us know if the issue still persists. Thank you.