Clone_model in Callback throws TypeError NoneType

Hello,
I want to implement checkpointing of my model in the “SavedModel” format (the folder-based one). Saving takes quite some time even for my small LSTM (around 10 s). That’s why I wanted to clone the model and save the clone in a separate thread, so that training is not held up for so long. When I try to clone the model, the following error occurs. Does anybody have a clue why this happens?
My code to save the model in the callback looks like this:

    def on_epoch_end(self, epoch: int, logs=None):
        cloned_model = tf.keras.models.clone_model(self.model)

The error:

Traceback (most recent call last):
  File "./main.py", line 35, in <module>
    main()
  File "./main.py", line 28, in main
    compile_and_fit(model, train_gen, val_gen, timestamp, config.patience)
  File "./training/compile_and_fit.py", line 28, in compile_and_fit
    history = model.fit(train_dfs, epochs=MAX_EPOCHS,
  File "./venv/lib/python3.9/site-packages/keras/engine/training.py", line 1230, in fit
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "./venv/lib/python3.9/site-packages/keras/callbacks.py", line 413, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "./checkpointing/MetricsCallback.py", line 48, in on_epoch_end
    cloned_model = tf.keras.models.clone_model(self.model)
  File "./venv/lib/python3.9/site-packages/keras/models.py", line 448, in clone_model
    return _clone_sequential_model(
  File "./venv/lib/python3.9/site-packages/keras/models.py", line 332, in _clone_sequential_model
    cloned_model = Sequential(layers=layers, name=model.name)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "./venv/lib/python3.9/site-packages/keras/engine/sequential.py", line 134, in __init__
    self.add(layer)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "./venv/lib/python3.9/site-packages/keras/engine/sequential.py", line 217, in add
    output_tensor = layer(self.outputs[0])
  File "./venv/lib/python3.9/site-packages/keras/layers/recurrent.py", line 659, in __call__
    return super(RNN, self).__call__(inputs, **kwargs)
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 976, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1114, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 848, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 886, in _infer_output_signature
    self._maybe_build(inputs)
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 2659, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "./venv/lib/python3.9/site-packages/keras/layers/recurrent.py", line 577, in build
    self.cell.build(step_input_shape)
  File "./venv/lib/python3.9/site-packages/keras/utils/tf_utils.py", line 259, in wrapper
    output_shape = fn(instance, input_shape)
  File "./venv/lib/python3.9/site-packages/keras/layers/recurrent.py", line 2354, in build
    self.kernel = self.add_weight(
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer.py", line 647, in add_weight
    variable = self._add_variable_with_custom_getter(
  File "./venv/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 813, in _add_variable_with_custom_getter
    new_variable = getter(
  File "./venv/lib/python3.9/site-packages/keras/engine/base_layer_utils.py", line 117, in make_variable
    return tf.compat.v1.Variable(
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 266, in __call__
    return cls._variable_v1_call(*args, **kwargs)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 212, in _variable_v1_call
    return previous_getter(
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 67, in getter
    return captured_getter(captured_previous, **kwargs)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py", line 3547, in creator
    return next_creator(**kwargs)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 205, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variable_scope.py", line 2612, in default_variable_creator
    return resource_variable_ops.ResourceVariable(
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/variables.py", line 270, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1602, in __init__
    self._init_from_args(
  File "./venv/lib/python3.9/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1740, in _init_from_args
    initial_value = initial_value()
  File "./venv/lib/python3.9/site-packages/keras/initializers/initializers_v2.py", line 499, in __call__
    fan_in, fan_out = _compute_fans(shape)
  File "./venv/lib/python3.9/site-packages/keras/initializers/initializers_v2.py", line 1009, in _compute_fans
    return int(fan_in), int(fan_out)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'

Hi @supermar10

Welcome to the TensorFlow Forum!

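This TypeError is raised while clone_model() rebuilds the layers, and it is not related to compiling: on_epoch_end() only runs inside fit(), so the model is already compiled at that point. The last frames of the traceback show the cloned LSTM cell building its kernel, whose shape is (input_dim, 4 * units). In the cloned model, input_dim (the feature dimension of the LSTM input) is None, so the weight initializer cannot compute fan_in and int(None) fails.
To fix this error, give the model an explicit input shape with a known number of features, for example by starting the Sequential model with tf.keras.Input(shape=(timesteps, num_features)) or by passing input_shape to the first layer. Also note that clone_model() does not copy weights: the clone gets newly initialized weights, so copy the trained ones with cloned_model.set_weights(self.model.get_weights()) before saving it.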
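A minimal sketch, assuming TF 2.x and hypothetical sizes of 10 time steps with 3 features, showing that a Sequential LSTM model declared with an explicit tf.keras.Input can be cloned, and that the clone starts with its own freshly initialized weights until set_weights() copies the original ones. TIMESTEPS, NUM_FEATURES and same_weights are names made up for this illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes: 10 time steps with 3 features each.
TIMESTEPS, NUM_FEATURES = 10, 3

# Declaring the input shape up front gives the LSTM a known input feature
# dimension, which clone_model() needs when it rebuilds the layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(TIMESTEPS, NUM_FEATURES)),
    tf.keras.layers.LSTM(8),
    tf.keras.layers.Dense(1),
])

cloned_model = tf.keras.models.clone_model(model)


def same_weights(a, b):
    """Return True if both models hold identical weight values."""
    return all(np.array_equal(x, y) for x, y in zip(a.get_weights(), b.get_weights()))


# The clone was randomly initialized, so its kernels differ from the original's.
weights_equal_before = same_weights(model, cloned_model)
# Copy the original's weights into the clone.
cloned_model.set_weights(model.get_weights())
weights_equal_after = same_weights(model, cloned_model)
print(weights_equal_before, weights_equal_after)
```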

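import tensorflow as tf

class Clone_Callback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch: int, logs=None):
        # Clone the architecture (clone_model() creates new, freshly initialized weights)
        cloned_model = tf.keras.models.clone_model(self.model)
        # Copy the trained weights into the clone
        cloned_model.set_weights(self.model.get_weights())
        # Save the cloned model
        cloned_model.save('cloned_model.h5')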
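Since the original goal was to keep the slow save out of the training loop, the clone can also be saved from a background thread. The sketch below is a hypothetical callback (BackgroundCheckpoint and its path_template parameter are names made up for illustration). It assumes the model has a fully known input shape, and Keras does not document Model.save() as thread-safe, so treat it as an experiment rather than a guaranteed pattern.

```python
import threading

import tensorflow as tf


class BackgroundCheckpoint(tf.keras.callbacks.Callback):
    """Hypothetical helper: clone the model at the end of every epoch and
    save the clone in a background thread while training continues."""

    def __init__(self, path_template):
        super().__init__()
        # Path with an "{epoch}" placeholder, e.g. "ckpt/epoch_{epoch}.h5"
        self.path_template = path_template
        self._threads = []

    def on_epoch_end(self, epoch, logs=None):
        # New layers and new variables, independent of self.model
        cloned_model = tf.keras.models.clone_model(self.model)
        # Copy the current trained weights into the clone
        cloned_model.set_weights(self.model.get_weights())
        # Keras passes 0-based epoch indices; number checkpoints from 1
        path = self.path_template.format(epoch=epoch + 1)
        # The slow save runs on the clone, so training on self.model is not blocked
        thread = threading.Thread(target=cloned_model.save, args=(path,))
        thread.start()
        self._threads.append(thread)

    def on_train_end(self, logs=None):
        # Wait for pending saves so no checkpoint is left half-written
        for thread in self._threads:
            thread.join()
        self._threads.clear()
```

Usage: pass callbacks=[BackgroundCheckpoint('checkpoints/epoch_{epoch}.h5')] to model.fit(...). With Keras 2 (as in the traceback), a path without the .h5 suffix is saved in the SavedModel folder format instead. Exceptions raised inside a save thread are only printed to stderr, not re-raised in the training loop.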

Please refer to this replicated gist for your reference. Thank you.