Implementation detail of tf.keras.callbacks.ModelCheckpoint

I’m trying to subclass ModelCheckpoint and I’m curious about an implementation detail in distributed environments. When saving the checkpoint, the callback will call:

self._write_filepath = distributed_file_utils.write_filepath(file_path, self.model.distribute_strategy)

which checks whether the model is being run under a distribution strategy and returns the original file_path for the “chief” worker and a temporary file path for all other workers. The callback will then call:

distributed_file_utils.remove_temp_dir_with_filepath(self._write_filepath, self.model.distribute_strategy)

which deletes the temporary file_path on the non-chief workers and leaves the original file_path in place for the chief worker.
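
For context, this is my understanding of how the two helpers are meant to be used around a save; the import path is internal to Keras and has moved around between versions (older TF releases expose it as tensorflow.python.keras.distribute.distributed_file_utils), so treat this as a rough sketch rather than the callback’s exact code:

```python
import os
import tensorflow as tf
# Internal Keras module; this import path is an assumption and may differ
# across versions.
from keras.distribute import distributed_file_utils

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

os.makedirs("/tmp/ckpt", exist_ok=True)
file_path = "/tmp/ckpt/weights.h5"

# On the chief worker this returns file_path unchanged; on non-chief workers
# it returns a path inside a temporary sibling directory.
write_path = distributed_file_utils.write_filepath(file_path, strategy)

model.save_weights(write_path)

# No-op for the chief; on non-chief workers it deletes the temporary
# directory again, so only the chief's checkpoint survives.
distributed_file_utils.remove_temp_dir_with_filepath(write_path, strategy)
```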

My question is: why take this roundabout route? Why not simply check whether the current worker is the chief and save the checkpoint only if it is?
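
Concretely, the alternative I have in mind would look roughly like the sketch below. is_chief and save_if_chief are just illustrative names I made up, not anything from the Keras code base, and the chief check assumes a strategy that exposes a cluster_resolver:

```python
import tensorflow as tf

def is_chief(strategy):
    # Illustrative chief check; single-worker strategies have no cluster
    # resolver and are treated as the chief here.
    resolver = strategy.cluster_resolver
    if resolver is None:
        return True
    return resolver.task_type in (None, "chief") or (
        resolver.task_type == "worker" and resolver.task_id == 0)

def save_if_chief(model, strategy, file_path):
    # Skip the temporary-directory dance entirely on non-chief workers.
    if is_chief(strategy):
        model.save_weights(file_path)
```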

Hi @Eddie

Welcome to the TensorFlow Forum!

Please share a minimal reproducible code snippet so we can replicate and understand the issue. You can refer to the Distributed Training with Keras page for how the ModelCheckpoint callback is used during model training. Thank you.