Does CRFModelWrapper in TFA support serialization for continued training?

Hi everyone,

I am using the CRFModelWrapper approach from the tutorial addons/layers_crf.ipynb (branch add_crf_tutorial of howl-anderson/addons on GitHub) to implement a BiLSTM-CRF neural network for a multi-class NER problem. The model I built (code shown below) can be trained on multiple GPUs, and it can be saved and loaded with the tf.keras model functions. However, when I save the model, it shows the warning below, and I am not sure whether it really matters. When I then load the trained model, it shows many warnings about inconsistent output shapes (also posted below). Finally, when I want to train the loaded model again with fit(), it fails with "No gradients provided for any variable".

Saving warning:

WARNING:absl:Found untraced functions such as embedding_layer_call_and_return_conditional_losses, embedding_layer_call_fn, embedding_1_layer_call_and_return_conditional_losses, embedding_1_layer_call_fn, multi_head_attention_layer_call_and_return_conditional_losses while saving (showing 5 of 65). These functions will not be directly callable after loading.

Loading warnings:

2021-11-21 00:28:34.222763: W tensorflow/core/common_runtime/graph_constructor.cc:803] Node 'cond/while' has 14 outputs but the _output_shapes attribute specifies shapes for 48 outputs. Output shapes may be inaccurate.
2021-11-21 00:28:34.374067: W tensorflow/core/common_runtime/graph_constructor.cc:803] Node 'cond' has 5 outputs but the _output_shapes attribute specifies shapes for 48 outputs. Output shapes may be inaccurate.
[... the same two warnings (for 'cond/while' and 'cond') repeat, 60 lines in total, through 2021-11-21 00:28:46.266813 ...]

Despite these warnings, the model can still be saved, loaded, and used for prediction, but the loaded model cannot be retrained. The error message is posted below; it seems that the loss function inside CRFModelWrapper cannot be called again, so the gradients cannot be computed. So I am wondering whether CRFModelWrapper does not support serialization (save -> load -> train), or whether I have made a mistake somewhere. If serialization is not supported, is there a workaround that would let me retrain the model?
Thank you very much.

Error message when re-training the trained model:

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py", line 1184, in fit
    tmp_logs = self.train_function(iterator)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 759, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\function.py", line 3066, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\function.py", line 3463, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\function.py", line 3298, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\framework\func_graph.py", line 1007, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 668, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\framework\func_graph.py", line 994, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:853 train_function  *
        return step_function(self, iterator)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:842 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632 _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:835 run_step  **
        outputs = model.train_step(data)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:791 train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\optimizer_v2\optimizer_v2.py:522 minimize
        return self.apply_gradients(grads_and_vars, name=name)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\optimizer_v2\optimizer_v2.py:622 apply_gradients
        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\optimizer_v2\utils.py:72 filter_empty_gradients
        raise ValueError("No gradients provided for any variable: %s." %
    ValueError: No gradients provided for any variable: ['chain_kernel:0', 'left_boundary:0', 'right_boundary:0', 'crf_model_wrapper/crf/dense/kernel:0', 'crf_model_wrapper/crf/dense/bias:0', 'embedding/embeddings:0', 'bidirectional/forward_bilstm/lstm_cell_1/kernel:0', 'bidirectional/forward_bilstm/lstm_cell_1/recurrent_kernel:0', 'bidirectional/forward_bilstm/lstm_cell_1/bias:0', 'bidirectional/backward_bilstm/lstm_cell_2/kernel:0', 'bidirectional/backward_bilstm/lstm_cell_2/recurrent_kernel:0', 'bidirectional/backward_bilstm/lstm_cell_2/bias:0', 'time_distributed/kernel:0', 'time_distributed/bias:0'].
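One hedged reading of this traceback: the `train_step` frame at keras/engine/training.py:791 calls `self.optimizer.minimize(loss, self.trainable_variables, tape=tape)`, which matches Keras's built-in `Model.train_step`, whereas the `train_step` of CRFModelWrapper (posted below) calls `self.optimizer.apply_gradients(...)` directly. If that reading is right, `load_model` returned a generic revived Keras model rather than a CRFModelWrapper, and because the model was compiled without a `loss`, the default step appears to compute a constant loss, which would explain "No gradients provided for any variable". A minimal sketch for checking this after loading; the helper name `dispatches_to` is hypothetical:

```python
def dispatches_to(model, expected_cls, method_name="train_step"):
    """Return True if `model` would run `expected_cls.<method_name>` when fit() calls it.

    Compares the function found on the model's class with the one defined on
    `expected_cls`. A model revived without its original Python class is
    expected to fall back to Keras's default implementation instead.
    """
    actual = getattr(type(model), method_name, None)
    expected = getattr(expected_cls, method_name, None)
    return actual is not None and actual is expected


# Usage (assumes CRFModelWrapper from the code below is defined in this session):
#   loaded = tf.keras.models.load_model(FILE_PATH4)
#   print(type(loaded).__name__)                    # which class Keras revived
#   print(dispatches_to(loaded, CRFModelWrapper))   # False => default train_step runs
```

Comparing the class attribute rather than using `isinstance` also catches the case where a subclass silently replaces `train_step`.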

Below is my code, excluding the data preprocessing.

import tensorflow as tf
import tensorflow_addons as tfa

#%% Build the base model_1
def build_bilstm_crf_model(
        lstm_unit,
        fc_unit
) -> tf.keras.Model:
    x = tf.keras.layers.Input(shape=(None,), dtype=tf.float32, name="inn")
    y = tf.keras.layers.Embedding(1, 1, mask_zero=True)(x)
    y = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(lstm_unit, return_sequences=True,name="bilstm")
    )(y)
    y = tf.keras.layers.TimeDistributed(
        tf.keras.layers.Dense(fc_unit,name="fc")
    )(y)
    return tf.keras.Model(
        inputs=x, outputs=y
    )

# CRF Wrapper Model
class CRFModelWrapper(tf.keras.Model):
    def __init__(
            self,
            model: tf.keras.Model,
            units: int,
            chain_initializer="orthogonal",
            use_boundary: bool = True,
            boundary_initializer="zeros",
            use_kernel: bool = True,
            **kwargs
    ):
        super().__init__()

        self.crf_layer = tfa.layers.CRF(
            units=units,
            chain_initializer=chain_initializer,
            use_boundary=use_boundary,
            boundary_initializer=boundary_initializer,
            use_kernel=use_kernel,
            **kwargs
        )

        self.base_model = model

    def unpack_training_data(self, data):
        # Override this method if this unpacking does not suit your task
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
            sample_weight = None
        return x, y, sample_weight

    def call(self, inputs, training=None, mask=None, return_crf_internal=False):
        base_model_outputs = self.base_model(inputs, training=training, mask=mask)

        # change next line, if your model has more outputs
        crf_input = base_model_outputs

        decode_sequence, potentials, sequence_length, kernel = self.crf_layer(crf_input)
        # potentials: per-token tag scores produced by the CRF layer (fed to the CRF loss during training)
        # Change the next line if your base model has more outputs.
        # Always keep `(potentials, sequence_length, kernel), decode_sequence`
        # as the first two outputs of the model;
        # the current `self.train_step()` expects this layout.
        outputs = (potentials, sequence_length, kernel), decode_sequence

        if return_crf_internal:
            return outputs
        else:
            # outputs[0] is the crf internal, skip it
            output_without_crf_internal = outputs[1:]

            # It is nicer to return a single tensor than a one-element list
            if len(output_without_crf_internal) == 1:
                return output_without_crf_internal[0]
            else:
                return output_without_crf_internal

    def compute_crf_loss(self, potentials, sequence_length, kernel, y, sample_weight=None):
        # Convert one-hot labels (batch, time, n_tags) to integer tag ids (batch, time);
        # integer labels (batch, time) are used as-is.
        if len(y.shape) > 2:
            y_1 = tf.argmax(y, -1, output_type=tf.int32)
        else:
            y_1 = tf.cast(y, tf.int32)
        crf_likelihood, _ = tfa.text.crf_log_likelihood(
            potentials, y_1, sequence_length, kernel
        )
        # convert likelihood to loss
        flat_crf_loss = -1 * crf_likelihood
        if sample_weight is not None:
            flat_crf_loss = flat_crf_loss * sample_weight
        crf_loss = tf.reduce_mean(flat_crf_loss)

        return crf_loss

    def train_step(self, data):
        x, y, sample_weight = self.unpack_training_data(data)
        with tf.GradientTape() as tape:
            (potentials, sequence_length, kernel), decoded_sequence, *_ = self(
                x, training=True, return_crf_internal=True
            )
            crf_loss = self.compute_crf_loss(
                potentials, sequence_length, kernel, y, sample_weight
            )
            loss = crf_loss + tf.reduce_sum(self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, decoded_sequence)
        # Return a dict mapping metric names to current value
        results = {m.name: m.result() for m in self.metrics}
        results.update({"loss": loss, "crf_loss": crf_loss})  # append loss
        return results

    def test_step(self, data):
        x, y, sample_weight = self.unpack_training_data(data)
        (potentials, sequence_length, kernel), decode_sequence, *_ = self(
            x, training=False, return_crf_internal=True
        )
        crf_loss = self.compute_crf_loss(
            potentials, sequence_length, kernel, y, sample_weight
        )
        loss = crf_loss + tf.reduce_sum(self.losses)
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, decode_sequence)
        # Return a dict mapping metric names to current value
        results = {m.name: m.result() for m in self.metrics}
        results.update({"loss": loss, "crf_loss": crf_loss})  # append loss
        return results

When training the model:

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
    base_model = build_bilstm_crf_model(units_lstm,TAG_SIZE)
    model = CRFModelWrapper(base_model, TAG_SIZE)
    model.compile(optimizer=tf.keras.optimizers.Adam(lr))
num_epochs = 10
name  = 'BLD_CRF_Lut{}_lr{}_{}'.format(units_lstm,lrr,int(time.time()))
name_csv = 'BLD_CRF_Lut{}_lr{}'.format(units_lstm,lrr)
mc = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(name,'{epoch:02d}'), verbose=1,
                     save_best_only=False,save_weights_only=False)
csv_log = tf.keras.callbacks.CSVLogger('{}.csv'.format(name_csv),append=True)
model.fit(tr_gen, epochs=num_epochs, validation_data = val_gen,verbose=2,batch_size = bt_sz,callbacks=[mc,csv_log])

When re-training the model:

#%%
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
# base_model = build_bilstm_crf_model(units_lstm,TAG_SIZE)
# model = CRFModelWrapper(base_model, TAG_SIZE)
# model.compile(optimizer=tf.keras.optimizers.Adam(lr))
    model = load_model(FILE_PATH4)
num_epochs = 10
name  = 'BLD_CRF_Lut{}_lr{}_{}'.format(units_lstm,lrr,int(time.time()))
name_csv = 'BLD_CRF_Lut{}_lr{}'.format(units_lstm,lrr)
mc = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(name,'{epoch:02d}'),verbose=1,save_best_only=False,save_weights_only=False)
csv_log = tf.keras.callbacks.CSVLogger('{}.csv'.format(name_csv),append=True)
model.fit(tr_gen, epochs=num_epochs, validation_data = val_gen,verbose=2,batch_size = bt_sz,callbacks=[mc,csv_log])
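A possible workaround, sketched under the assumption that the original Python classes are available in the retraining session: rebuild the wrapper in code and restore only its variables with `tf.keras.Model.save_weights` / `load_weights` (TensorFlow checkpoint format), instead of relying on `load_model` to reconstruct the subclassed model. `build_bilstm_crf_model`, `CRFModelWrapper`, `units_lstm`, `TAG_SIZE`, `lr`, `strategy`, and the generators come from the posts above; `WEIGHTS_DIR` and `rebuild_for_retraining` are hypothetical names.

```python
# Hypothetical directory for weight-only checkpoints (not from the original post).
WEIGHTS_DIR = "BLD_CRF_weights"


def rebuild_for_retraining(build_fn, weights_path=None):
    """Rebuild the model from Python code, then optionally restore saved weights.

    `build_fn` must return a compiled model (e.g. a CRFModelWrapper). Because the
    object is created from the Python class, fit() uses its custom train_step;
    only the variable values are read from `weights_path`.
    """
    model = build_fn()
    if weights_path is not None:
        # With TF-format checkpoints, variables that do not exist yet are
        # restored when they are created (e.g. on the model's first call).
        model.load_weights(weights_path)
    return model


# The lines below assume `os` and `tf` are imported, as in the original script.
#
# First training run: write weight-only checkpoints instead of SavedModels.
#   mc = tf.keras.callbacks.ModelCheckpoint(
#       filepath=os.path.join(WEIGHTS_DIR, "cp-{epoch:02d}.ckpt"),
#       save_weights_only=True, verbose=1)
#
# Retraining run:
#   def build_fn():
#       base_model = build_bilstm_crf_model(units_lstm, TAG_SIZE)
#       model = CRFModelWrapper(base_model, TAG_SIZE)
#       model.compile(optimizer=tf.keras.optimizers.Adam(lr))
#       return model
#
#   with strategy.scope():
#       model = rebuild_for_retraining(build_fn, tf.train.latest_checkpoint(WEIGHTS_DIR))
#   model.fit(tr_gen, epochs=num_epochs, validation_data=val_gen, callbacks=[mc, csv_log])
```

The design choice here is to keep the model's behavior (the custom `train_step`/`test_step`) in code and persist only state, which sidesteps whatever `load_model` can or cannot reconstruct for a subclassed model that defines no `get_config`.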

/cc @XiaoquanKong can you check this?

@Bhack Thank you for your kind reminder, I will handle this.

Hi @dada_Lai, thank you for your bug report. I am trying to reproduce this bug on my computer. If I find the root cause or need your help, I will let you know.

Hi @dada_Lai, I have tried to reproduce this bug on Colab, but it works fine in my notebook. You can find my notebook on Google Colab, and you can reload the model in a separate notebook, also on Google Colab. To run the second notebook, you need to copy the model produced by the first notebook into the second notebook's workspace by mounting Google Drive. If you have any questions, please let me know.

Hi @XiaoquanKong ,

In case you missed the response I posted in another related thread, I am re-posting my results below.

I ran the same code with the same test data provided above on Google Colab, and it shows only a few minor warnings when saving:

WARNING:absl:Found untraced functions such as dense_1_layer_call_and_return_conditional_losses, dense_1_layer_call_fn, dense_1_layer_call_fn, dense_1_layer_call_and_return_conditional_losses, dense_1_layer_call_and_return_conditional_losses while saving (showing 5 of 15). These functions will not be directly callable after loading.

Have you tested the code with the TensorFlow GPU package (i.e. tensorflow-gpu==2.6.0)? That is the main difference between my tests on Google Colab (tensorflow==2.6.0) and on my PC (tensorflow-gpu==2.6.0). So far I haven't been able to run the code with the tensorflow-gpu package on Google Colab; I will let you know once I get it working.
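To compare the two environments, it may help to confirm which TensorFlow-related distributions are actually installed on each machine. Since TensorFlow 2.1, the regular `tensorflow` pip package already includes GPU support on Linux and Windows, so the two 2.6.0 packages are expected to contain the same build; comparing the installed `tensorflow-addons` versions may also be worthwhile, since the CRF layer lives there. A minimal stdlib-only sketch (the helper name `installed_tf_distributions` and the candidate list are assumptions):

```python
from importlib import metadata

# Distributions worth checking. "tensorflow" and "tensorflow-gpu" both install the
# same `tensorflow` import package, so having both in one environment can conflict.
CANDIDATES = ("tensorflow", "tensorflow-gpu", "tensorflow-cpu", "tensorflow-addons")


def installed_tf_distributions(names=CANDIDATES):
    """Return {distribution name: installed version} for the names that are installed."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            continue  # not installed in this environment
    return found


if __name__ == "__main__":
    print(installed_tf_distributions())
    # The GPUs TensorFlow can actually see at runtime are listed by:
    #   import tensorflow as tf; print(tf.config.list_physical_devices("GPU"))
```

Running this in both the Colab notebook and the local environment gives two dictionaries that can be compared directly.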

Thank you very much.

I see that @XiaoquanKong's Colab is on a GPU runtime. What is the problem?

Hi @Bhack,

Thank you for following up. My point is that the problem of retraining the model with the CRFModelWrapper API still exists when I use the tensorflow-gpu package. At first I thought the example provided by @XiaoquanKong used only tensorflow and not tensorflow-gpu, but after your clarification it seems there is no difference between them. So I am somewhat confused, and I am wondering whether the CRFModelWrapper API works properly in a tensorflow-gpu environment. Please let me know if anything is unclear.

Thanks.

tensorflow-gpu is only for older, now end-of-life versions:

Got it.

Thank you all very much.