Why doesn’t model.build() set input shape?

I’ve got a subclassed model that I’m trying to save with model.save(), but I get the following error:

Model <__main__.ReId object at 0x7fe554658590> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model.build(input_shape).

This happens even though I explicitly call model.build(input_shape=(256,256,3)) to set the input shape.

I’ve realised that this only happens when I use my custom BatchDataset. When I run model.fit() on a dataset generated by an ImageDataGenerator, the model saves normally.
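For comparison, here is a minimal sketch of the shapes involved, assuming synthetic zero-filled 256x256 RGB images and integer labels in place of the real vehicle data (the arrays and batch size of 4 are illustrative assumptions). It shows that a dataset batched with tf.data's .batch() carries a leading batch dimension of None in its element spec, while the shape passed to model.build() above has only three dimensions. It does not reproduce the ReId model or the save error itself.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the real images (assumption: 256x256 RGB, 8 samples).
images = np.zeros((8, 256, 256, 3), dtype=np.float32)
labels = np.zeros((8,), dtype=np.int32)

# Calling .batch() on a tf.data pipeline produces a batched dataset; its
# elements gain a leading batch dimension, reported as None.
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

image_spec, label_spec = dataset.element_spec
print(image_spec.shape)  # (None, 256, 256, 3)

# The shape passed to model.build() in the question has no batch dimension.
build_shape = (256, 256, 3)
print(len(build_shape), image_spec.shape.rank)  # 3 4
```

The mismatch in rank (3 versus 4) is only an observation about the shapes; whether it explains the save error depends on how the ReId model's call() is written.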

The full code is available at the link below:
https://vehiclereidjupyternotebook.s3.eu-west-2.amazonaws.com/broken_saving_tf.html

Have you already tried: