Predictions from SimCLR in the Model Garden

Hi all,

I have been experimenting with the semi-supervised learning algorithm SimCLR as implemented in the Model Garden, under official.vision.beta.projects.simclr. I can successfully pretrain and fine-tune the model on custom data, and the summary metrics look extremely promising. However, I cannot figure out how to load a trained model and generate predictions on new data. I would really appreciate some help.

I think one specific problem I’m having is with generating or loading the checkpoints produced during training. Let’s say I do a training run: I get logs of the validation metrics and the fully trained model checkpoint. If I then load the checkpoint and compute performance metrics again on the same validation data, I get vastly different numbers. Validation metrics go from very good to no better than chance. It seems to me that either the checkpoint doesn’t store all the weights the model needs, or I’m loading the checkpoint wrong. Can anyone comment?

I have tried a number of methods to save the full model to disk once it is finished training. None of them work. I would appreciate some insight here if anyone has any.

  • The model object has its own model.save() method, which I believe is inherited from Keras. This method fails with a complaint that the SimCLR ResNet backbone is somehow incompatible with what Keras expects. On CIFAR-10, for example, I get a message that the model I’m trying to save has an input shape of [32,32,6], which is incompatible with the expected shape of [32,32,3].
  • The TensorFlow Model Garden library has its own export method, tfm.core.export_base.export(). This also fails, I believe with a complaint that the method isn’t implemented for SimCLR.
  • I have tried tf.saved_model.save(), which also fails.
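The [32,32,6] shape in the first bullet is consistent with SimCLR pretraining stacking two augmented RGB views of each image along the channel axis. As a minimal NumPy sketch, assuming that layout (the batch size and the split/concatenate steps below are illustrative, not copied from the Model Garden code), this is how such an input maps back onto a standard 3-channel backbone:

```python
import numpy as np

# Hypothetical batch of 4 pretraining inputs: two augmented 32x32 RGB views
# stacked along the channel axis, giving shape (4, 32, 32, 6).
batch = np.random.rand(4, 32, 32, 6).astype(np.float32)

# Split the channel axis back into two 3-channel views.
view_a, view_b = np.split(batch, 2, axis=-1)

# Stack the views along the batch axis so a 3-channel backbone
# sees 8 ordinary RGB images.
backbone_input = np.concatenate([view_a, view_b], axis=0)

print(view_a.shape, backbone_input.shape)  # (4, 32, 32, 3) (8, 32, 32, 3)
```

If this assumption holds, a model whose declared input is 6 channels is the pretraining-mode model, while inference on new single images needs a 3-channel input.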

Thanks for any and all help.

@Austin Welcome to the TensorFlow Forum!

It’s great to hear that you’ve successfully pretrained and fine-tuned the SimCLR model on your custom data and achieved promising summary metrics. It’s understandable that loading the trained model and generating predictions on new data is proving harder.

Let’s address your specific issues one at a time:

1. Discrepancy in validation metrics:

The large gap between the validation metrics logged during training and those computed after loading the checkpoint suggests a problem with either how the checkpoint is loaded or how the metrics are computed afterwards.

Possible causes and solutions:

  • Incomplete checkpoint loading: Ensure you’re loading the entire checkpoint and not just a subset of the weights. The Model Garden implementation might require specific steps for loading the complete model.
  • Different evaluation code: Verify that you’re using the same evaluation code for both training and loading the checkpoint. Differences in data pre-processing, evaluation metrics, or batching can lead to discrepancies.
  • Model corruption: Check for any errors or warnings during training or loading the checkpoint. Sometimes, unexpected errors might corrupt the model weights.
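One way the first cause can arise is a restore that silently matches nothing. The sketch below uses a hypothetical toy Keras model in place of the SimCLR model, and the checkpoint keys `model` and `net` are assumptions (check which key the Model Garden trainer actually uses). `tf.train.Checkpoint` matches variables by the attribute names used at save time, so restoring under a different key leaves the model at its random initialization, and `expect_partial()` suppresses the warning that would reveal this. `assert_existing_objects_matched()` turns that silent failure into an error:

```python
import tempfile

import numpy as np
import tensorflow as tf


def build_model() -> tf.keras.Model:
    # Hypothetical toy stand-in for the SimCLR classifier.
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
    outputs = tf.keras.layers.Dense(10)(x)
    return tf.keras.Model(inputs, outputs)


trained = build_model()
ckpt_dir = tempfile.mkdtemp()
# The model is saved under the key "model"; a restore must use the same key.
save_path = tf.train.Checkpoint(model=trained).save(f"{ckpt_dir}/ckpt")

# Wrong: key "net" matches nothing in the checkpoint, and expect_partial()
# hides the warning, so this model keeps its random initial weights.
wrong = build_model()
tf.train.Checkpoint(net=wrong).restore(save_path).expect_partial()

# Right: same key, and the assertion raises if any existing variable
# of the model was not found in the checkpoint.
right = build_model()
status = tf.train.Checkpoint(model=right).restore(save_path)
status.assert_existing_objects_matched()
```

A model that evaluates at chance after loading behaves like `wrong` here, so asserting on the restore status is a cheap first check.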

2. Saving the full model:

Saving the full model to disk requires specific methods depending on the implementation. Here are some possibilities:

  • Model Garden checkpointing: Check if the Model Garden provides a built-in method for saving the complete model. It might be a specific function or class method within the SimCLR implementation.
  • TensorFlow SavedModel: The tf.saved_model.save function saves the model’s weights together with its traced computation graphs, so the model can be reloaded without the original Python code.
  • Keras Model serialization: If the model is built with Keras, you can use the model.save method to save the complete model to a file.
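If the saving errors come from the model’s 6-channel pretraining input, one workaround is to export a serving function with an explicit 3-channel input signature. This is a minimal sketch, assuming a hypothetical toy Keras model standing in for your fine-tuned classifier; whether the Model Garden SimCLR model can be called directly on 3-channel images (for example in its fine-tuning mode) is an assumption you would need to verify. The `Serving` wrapper and the `logits` output key are illustrative names:

```python
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical toy stand-in for the fine-tuned classifier; it expects
# ordinary 3-channel images.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs, outputs)


class Serving(tf.Module):
    """Wraps the model with an explicit 3-channel input signature."""

    def __init__(self, model: tf.keras.Model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[tf.TensorSpec([None, 32, 32, 3], tf.float32)])
    def serve(self, images):
        # training=False so layers such as batch norm use their moving statistics.
        return {"logits": self.model(images, training=False)}


export_dir = tempfile.mkdtemp()
module = Serving(model)
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.serve})

# Reload and predict on new data; signature functions take keyword arguments.
reloaded = tf.saved_model.load(export_dir)
logits = reloaded.signatures["serving_default"](images=tf.zeros([2, 32, 32, 3]))["logits"]
```

Fixing the input signature means the export never has to infer the input shape from the model, which is where the [32,32,6] versus [32,32,3] conflict appears.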

Additional tips for debugging:

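  • Print model summary: After loading the checkpoint, print the model summary to confirm that the layers and parameter counts match the trained model. The summary does not show whether weight values were actually restored, so pair it with the output comparison below.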
  • Visualize predictions: Analyze the predictions on a small sample of your data to identify any potential issues with the model’s behavior.
  • Compare outputs: Compare the outputs (activations or logits) from training and loaded models on the same input data to identify any differences.
  • Community resources: Search online forums and communities for similar issues encountered by other users working with the Model Garden SimCLR implementation.
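The "Compare outputs" tip can be made concrete by checking which layers’ weights differ between the in-memory model at the end of training and the reloaded one, and how far their logits diverge. This is a minimal sketch with a hypothetical toy model; the partial restore is simulated by copying only the `conv` layer, and the helper names `diff_layers` and `max_logit_gap` are illustrative:

```python
import numpy as np
import tensorflow as tf


def build_model() -> tf.keras.Model:
    # Hypothetical toy stand-in for the SimCLR backbone plus classification head.
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.layers.Conv2D(8, 3, name="conv")(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D(name="pool")(x)
    outputs = tf.keras.layers.Dense(10, name="head")(x)
    return tf.keras.Model(inputs, outputs)


def diff_layers(model_a, model_b):
    """Names of layers whose weights differ between two same-architecture models."""
    differing = []
    for layer_a, layer_b in zip(model_a.layers, model_b.layers):
        for w_a, w_b in zip(layer_a.get_weights(), layer_b.get_weights()):
            if not np.allclose(w_a, w_b):
                differing.append(layer_a.name)
                break
    return differing


def max_logit_gap(model_a, model_b, images):
    """Largest absolute difference between the two models' outputs."""
    a = model_a(images, training=False).numpy()
    b = model_b(images, training=False).numpy()
    return float(np.max(np.abs(a - b)))


trained = build_model()
loaded = build_model()
# Simulate a partial restore: the backbone weights are copied, the head is not.
loaded.get_layer("conv").set_weights(trained.get_layer("conv").get_weights())

images = np.random.rand(4, 32, 32, 3).astype(np.float32)
print(diff_layers(trained, loaded))  # ['head']
```

An unrestored head with a correctly restored backbone is one pattern that would produce chance-level metrics, and this check points straight at it.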

By investigating these possibilities and utilizing the debugging tips, you should be able to identify the root cause and successfully load your trained model for generating predictions on new data.

Let us know if this helps!