Model checkpointing best practices when using `train_step()`

Subclassing tf.keras.Model and overriding its train_step() function gives us the kind of flexibility we need to control our training loops. It allows us to easily plug in our favorite callbacks and keep most of the conveniences that are readily available with .fit().

I am wondering how to use the ModelCheckpoint callback in this instance. Consider the following use-case (taken from here):

import tensorflow as tf

class SimSiam(tf.keras.Model):
    def __init__(self, encoder, predictor):
        super(SimSiam, self).__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

How do I set up a ModelCheckpoint callback for this one?


Can you use something like Training checkpoints | TensorFlow Core?


Of course. But that somehow defeats the purpose of progressive disclosure of complexity IMO. I wanted to be able to focus on my training loop and delegate the rest to the framework wherever possible.

And the link does not suggest a concrete workaround for the use-case I mentioned.


This covers more than what you are looking for, as it also shows how to write a custom callback in a custom training loop:

But even in your case, if you manually populate the CallbackList, I think you still need to trigger the epoch events yourself in your custom training loop.
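Roughly, that manual wiring looks like the sketch below. This is an illustrative example, not code from the linked guide: the model, data, and checkpoint filename are placeholders, and the key point is that CallbackList only fires ModelCheckpoint when you call the epoch events yourself.

```python
import tensorflow as tf

# Placeholder model and data for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Wrap the callbacks in a CallbackList bound to the model.
callbacks = tf.keras.callbacks.CallbackList(
    [tf.keras.callbacks.ModelCheckpoint("ckpt.weights.h5", save_weights_only=True)],
    add_history=True,
    model=model,
)

xs = tf.random.normal((32, 4))
ys = tf.random.normal((32, 1))

callbacks.on_train_begin()
for epoch in range(2):
    callbacks.on_epoch_begin(epoch)
    logs = model.train_on_batch(xs, ys, return_dict=True)
    # The epoch-end event is what actually triggers ModelCheckpoint's save.
    callbacks.on_epoch_end(epoch, logs=logs)
callbacks.on_train_end()
```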


I guess for that I might need to discard the train_step() override, which I don’t want to do. I will study the links you shared and get back.


Could you clarify the question? Why doesn’t it work to just pass the callback to .fit like you normally would?

Yes. If you’re writing your own training loop, you need to drive the callbacks yourself using a CallbackList, as shown here:

But if you’re using train_step() and .fit(), all the callbacks should be driven as normal, no?


Callback list seems to be a good option. I will try it out.

If your subclassed model (where I am overriding train_step()) contains two or more models, and you pass a ModelCheckpoint callback while calling .fit() on the subclassed model, the callback gets confused.


That shouldn’t happen. Models are supposed to be nestable.


The problem here is that the callback defaults to saving the model in HDF5 format (which apparently requires calling .fit() to set the input shape, and we don’t call .fit() on the nested models).

Set save_weights_only=True to save in the TensorFlow checkpoint format instead, and then it works.
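Putting that together, here is a minimal sketch of the fix: a subclassed model with two nested networks and an overridden train_step(), checkpointed by passing ModelCheckpoint(save_weights_only=True) straight to .fit(). The tiny encoder/predictor networks and the stand-in loss below are placeholders, not the real SimSiam implementation.

```python
import tensorflow as tf

class SimSiamLike(tf.keras.Model):
    def __init__(self, encoder, predictor):
        super().__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z = self.encoder(data)
            p = self.predictor(z)
            loss = tf.reduce_mean(tf.square(p - z))  # stand-in loss, not SimSiam's
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

# Placeholder nested networks.
encoder = tf.keras.Sequential([tf.keras.layers.Dense(8)])
predictor = tf.keras.Sequential([tf.keras.layers.Dense(8)])
model = SimSiamLike(encoder, predictor)
model.compile(optimizer="sgd")

# save_weights_only=True avoids the full-model HDF5 serialization path.
ckpt = tf.keras.callbacks.ModelCheckpoint(
    "simsiam.weights.h5", save_weights_only=True
)
model.fit(tf.random.normal((16, 4)), epochs=1, callbacks=[ckpt], verbose=0)
```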


Okay. Let me test-drive this on the following since it has two networks present in the subclassed model SimSiam:

But I would think it would be ambiguous for the callback to determine which network (out of the two) it should serialize.


My hunch was totally wrong, it seems. It works right off the bat with save_weights_only=True:


But I would think it would be ambiguous for the callback to determine which network (out of the two) it should serialize.

It saves a checkpoint of the whole SimSiam model, which captures both of the nested models.
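A self-contained way to verify that behavior: save the outer model's weights and restore them into a fresh instance, then check that the variables of both nested sub-models match. The Outer class and its placeholder layers below are illustrative, not the actual SimSiam code.

```python
import tensorflow as tf

class Outer(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.encoder = tf.keras.Sequential([tf.keras.layers.Dense(4)])
        self.predictor = tf.keras.Sequential([tf.keras.layers.Dense(4)])

    def call(self, x):
        return self.predictor(self.encoder(x))

m1 = Outer()
m1(tf.zeros((1, 4)))  # build the variables before saving
m1.save_weights("outer.weights.h5")

m2 = Outer()
m2(tf.zeros((1, 4)))  # build a fresh instance with different random weights
m2.load_weights("outer.weights.h5")

# After restoring, every weight of both nested models matches.
same = all(
    tf.reduce_all(a == b).numpy()
    for a, b in zip(m1.weights, m2.weights)
)
```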


That’s what I noticed too. Sorry for the bother.