Pruning after adding a layer to an already pruned model

Hello folks!

I am trying to test a structurally growing model.
For this, I created a base model with a ResNet classifier as follows:

Input
Dense
GlobalAveragePooling2D
Flatten
Dense
Dense
Dense

After pruning the Dense layers, I get:

Input
pruned_low_magnitude(Dense)
GlobalAveragePooling2D
Flatten
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)

Then I add a Dense layer after the first pruned Dense layer:

Input
pruned_low_magnitude(Dense)
Dense
GlobalAveragePooling2D
Flatten
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)

This is where my problems begin.
The steps above should be repeated until I reach the accuracy and model size I want.

So, in the end, the model should look something like this:

Input
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)   (one added per iteration)
…
pruned_low_magnitude(Dense)
GlobalAveragePooling2D
Flatten
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)
pruned_low_magnitude(Dense)

So I first create a base model, then compile and train it.
After this, I prune the model and add a layer after the first pruned Dense layer.
Then I try to compile and train again.

I use the CIFAR-10 dataset.

The UpdatePruningStep() callback is passed to every fit() call (via `callbacks=`).
The pruning is done via:

def apply_pruning_to_dense(self, layer):
    """Wrap plain Dense layers for magnitude pruning; pass anything else through.

    Meant to be used as the ``clone_function`` of
    ``tf.keras.models.clone_model``, which calls it for each layer being copied
    and puts the returned layer into the new model.

    Every layer wrapped here must later be trained with
    ``tfmot.sparsity.keras.UpdatePruningStep`` in ``fit(callbacks=...)``;
    otherwise the wrapper's step assertion fails at training time.

    Args:
        layer: The layer currently being cloned.

    Returns:
        ``tfmot.sparsity.keras.prune_low_magnitude(layer)`` if ``layer`` is a
        ``tf.keras.layers.Dense``; otherwise ``layer`` itself (the same object,
        so it is shared with the source model).
    """
    # Guard clause: non-Dense layers are reused as-is. Presumably this also
    # covers layers wrapped in an earlier pruning pass, since the pruning
    # wrapper is not itself a Dense instance -- confirm for your tfmot version.
    if not isinstance(layer, tf.keras.layers.Dense):
        return layer
    return tfmot.sparsity.keras.prune_low_magnitude(layer)

# Copy `model`, handing each copied layer to self.apply_pruning_to_dense:
# Dense layers come back wrapped in prune_low_magnitude, while every other layer
# is returned unchanged and is therefore the very same object (same weights)
# in both `model` and `pruned_model`.
pruned_model = tf.keras.models.clone_model(model, clone_function=self.apply_pruning_to_dense)

I receive the following error:

InvalidArgumentError                      Traceback (most recent call last)
Cell In[3], line 20
      6 epochs = 1
      8 nv.save_path = os.path.join(
      9     PWD,
     10     'tmp',
   (...)
     17     str(depth)
     18 )
---> 20 nv.train_model_iteratively(base_model, iterations, DENSE_16, epochs)

File /tf/notebooks/naive.py:364, in naive.train_model_iteratively(self, model, iterations, new_layer_attributes, epochs, layer_id)
    362 compiled_model = self.compile_extended_model(pruned_model)
    363 print(str(datetime.now()) + ": Training 2/2 " + str(i) + ' Pruning: ' + str(pruned))
--> 364 current_model = self.train_model(compiled_model, epochs, pruned)
    365 pruned_model = current_model
    367 #if i > 0:

File /tf/notebooks/naive.py:228, in naive.train_model(self, model, epochs, pruned)
    225 cbs = self.update_Pruning_Step(pruned)
    227 try:    
--> 228     history = model.fit(self.training_images,
    229                         self.training_labels,
    230                         epochs=epochs, 
    231                         validation_data = (self.validation_images, self.validation_labels), 
    232                         batch_size=self.batchsize,
    233                         callbacks= cbs,
    234                         verbose= self.debug,
    235                         use_multiprocessing=True
    236                         )
    237 except tf.errors.ResourceExhaustedError:
    238     print("NOT ENOUGH RESOURCES!\n!!! SKIPPED !!!") 

File /usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py:61, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     59 def error_handler(*args, **kwargs):
     60     if not tf.debugging.is_traceback_filtering_enabled():
---> 61         return fn(*args, **kwargs)
     63     filtered_tb = None
     64     try:

File /usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py:1783, in Model.fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1775 with tf.profiler.experimental.Trace(
   1776     "train",
   1777     epoch_num=epoch,
   (...)
   1780     _r=1,
   1781 ):
   1782     callbacks.on_train_batch_begin(step)
-> 1783     tmp_logs = self.train_function(iterator)
   1784     if data_handler.should_sync:
   1785         context.async_wait()

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/util/traceback_utils.py:141, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    139 try:
    140   if not is_traceback_filtering_enabled():
--> 141     return fn(*args, **kwargs)
    142 except NameError:
    143   # In some very rare cases,
    144   # `is_traceback_filtering_enabled` (from the outer scope) may not be
    145   # accessible from inside this function
    146   return fn(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:831, in Function.__call__(self, *args, **kwds)
    828 compiler = "xla" if self._jit_compile else "nonXla"
    830 with OptionalXlaContext(self._jit_compile):
--> 831   result = self._call(*args, **kwds)
    833 new_tracing_count = self.experimental_get_tracing_count()
    834 without_tracing = (tracing_count == new_tracing_count)

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:904, in Function._call(self, *args, **kwds)
    900     pass  # Fall through to cond-based initialization.
    901   else:
    902     # Lifting succeeded, so variables are initialized and we can run the
    903     # no_variable_creation function.
--> 904     return tracing_compilation.call_function(
    905         args, kwds, self._no_variable_creation_config
    906     )
    907 else:
    908   bound_args = self._concrete_variable_creation_fn.function_type.bind(
    909       *args, **kwds
    910   )

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:139, in call_function(args, kwargs, tracing_options)
    137 bound_args = function.function_type.bind(*args, **kwargs)
    138 flat_inputs = function.function_type.unpack_inputs(bound_args)
--> 139 return function._call_flat(  # pylint: disable=protected-access
    140     flat_inputs, captured_inputs=function.captured_inputs
    141 )

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py:1264, in ConcreteFunction._call_flat(self, tensor_inputs, captured_inputs)
   1260 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
   1261 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
   1262     and executing_eagerly):
   1263   # No tape is watching; skip to running the function.
-> 1264   return self._inference_function.flat_call(args)
   1265 forward_backward = self._select_forward_and_backward_functions(
   1266     args,
   1267     possible_gradient_type,
   1268     executing_eagerly)
   1269 forward_function, args_with_tangents = forward_backward.forward()

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py:217, in AtomicFunction.flat_call(self, args)
    215 def flat_call(self, args: Sequence[core.Tensor]) -> Any:
    216   """Calls with tensor inputs and returns the structured output."""
--> 217   flat_outputs = self(*args)
    218   return self.function_type.pack_output(flat_outputs)

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py:252, in AtomicFunction.__call__(self, *args)
    250 with record.stop_recording():
    251   if self._bound_context.executing_eagerly():
--> 252     outputs = self._bound_context.call_function(
    253         self.name,
    254         list(args),
    255         len(self.function_type.flat_outputs),
    256     )
    257   else:
    258     outputs = make_call_op_in_graph(
    259         self,
    260         list(args),
    261         self._bound_context.function_call_options.as_attrs(),
    262     )

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/context.py:1479, in Context.call_function(self, name, tensor_inputs, num_outputs)
   1477 cancellation_context = cancellation.context()
   1478 if cancellation_context is None:
-> 1479   outputs = execute.execute(
   1480       name.decode("utf-8"),
   1481       num_outputs=num_outputs,
   1482       inputs=tensor_inputs,
   1483       attrs=attrs,
   1484       ctx=self,
   1485   )
   1486 else:
   1487   outputs = execute.execute_with_cancellation(
   1488       name.decode("utf-8"),
   1489       num_outputs=num_outputs,
   (...)
   1493       cancellation_manager=cancellation_context,
   1494   )

File /usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/execute.py:60, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53   # Convert any objects of type core_types.Tensor to Tensor.
     54   inputs = [
     55       tensor_conversion_registry.convert(t)
     56       if isinstance(t, core_types.Tensor)
     57       else t
     58       for t in inputs
     59   ]
---> 60   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     61                                       inputs, attrs, num_outputs)
     62 except core._NotOkStatusException as e:
     63   if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/Assert/AssertGuard/Assert defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/usr/local/lib/python3.11/dist-packages/ipykernel_launcher.py", line 17, in <module>

  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 1046, in launch_instance

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 736, in start

  File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.11/asyncio/base_events.py", line 604, in run_forever

  File "/usr/lib/python3.11/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.11/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 505, in process_one

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 740, in execute_request

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 422, in do_execute

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 546, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3024, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3526, in run_code

  File "/tmp/ipykernel_20788/595425178.py", line 20, in <module>

  File "/tf/notebooks/naive.py", line 364, in train_model_iteratively

  File "/tf/notebooks/naive.py", line 228, in train_model

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1783, in fit

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1377, in train_function

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1360, in step_function

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1349, in run_step

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1126, in train_step

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 589, in __call__

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/functional.py", line 515, in call

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 303, in call

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 51, in smart_cond

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 55, in smart_cond

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 285, in add_update

Detected at node model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/Assert/AssertGuard/Assert defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/usr/local/lib/python3.11/dist-packages/ipykernel_launcher.py", line 17, in <module>

  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 1046, in launch_instance

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 736, in start

  File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 195, in start

  File "/usr/lib/python3.11/asyncio/base_events.py", line 604, in run_forever

  File "/usr/lib/python3.11/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.11/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 505, in process_one

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 740, in execute_request

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 422, in do_execute

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 546, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3024, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3526, in run_code

  File "/tmp/ipykernel_20788/595425178.py", line 20, in <module>

  File "/tf/notebooks/naive.py", line 364, in train_model_iteratively

  File "/tf/notebooks/naive.py", line 228, in train_model

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1783, in fit

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1377, in train_function

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1360, in step_function

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1349, in run_step

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 1126, in train_step

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/training.py", line 589, in __call__

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/functional.py", line 515, in call

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 61, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 303, in call

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 51, in smart_cond

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/keras/utils.py", line 55, in smart_cond

  File "/usr/local/lib/python3.11/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py", line 285, in add_update

2 root error(s) found.
  (0) INVALID_ARGUMENT:  assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:] [x (model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/ReadVariableOp:0) = ] [0] [y (model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/y:0) = ] [1]
	 [[{{node model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/Assert/AssertGuard/Assert}}]]
	 [[model/prune_low_magnitude_rn_classification/assert_greater_equal/Assert/AssertGuard/pivot_f/_113/_139]]
  (1) INVALID_ARGUMENT:  assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:] [x (model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/ReadVariableOp:0) = ] [0] [y (model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/y:0) = ] [1]
	 [[{{node model/prune_low_magnitude_iteratively_added_dense_2/assert_greater_equal/Assert/AssertGuard/Assert}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_37577]