Self-supervised contrastive learning with SimSiam

Hi folks,

I am delighted to share my latest example on keras.io - Self-supervised contrastive learning with SimSiam:

https://keras.io/examples/vision/simsiam

This one introduces you to the world of self-supervised learning for computer vision and, at the same time, walks you through SimSiam, a solid self-supervised learning method for the field.

Self-supervision has long been predominant in the NLP world, but it is only in the past year that these methods have shown real progress for vision systems. We now have evidence that they can often match or beat their supervised counterparts, with the added benefit that their pre-training does not require any labeled data (though it does rely on far stronger inductive priors).

Happy to address any feedback on the post.


Thank you for posting this tutorial! I’ve been trying to use it as a baseline for something I’m working on, and I am having trouble reproducing the original paper’s results with the example. In supplemental section D of the SimSiam paper, they report results on CIFAR-10 after training for what I presume is 100 epochs. Carefully comparing your code with the paper, they seem to be the same aside from a few hyperparameter changes and the data augmentation. I added a true random-cropping function to the code you provided and have tried training for up to 800 epochs. Despite all this, I never get a linear evaluation higher than 35%. Any idea what else might be missing from the tutorial?


Please note that Keras Examples are primarily meant for demonstrating workflows as noted here.

I added a true random cropping function to the code you provided

I’m not sure what you meant by a true random cropping function.

Is your backbone architecture the same as the paper’s? The original one is ResNet-18. Even if that is the case, you’d need to ensure it follows the same hyperparameters (like the zero-initialized gamma in the final BatchNorm layer of each residual block, the one right before the residual addition, etc.). Following what’s noted under the supplementary A. Implementation Details section is important here, barring the CIFAR-10-specific changes.
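For reference, that detail would look something like the following in Keras, assuming it is the usual zero-gamma trick:

import tensorflow as tf

# Last BatchNormalization in a residual block, just before the addition
# with the shortcut: zero-initializing gamma makes the block start out
# as an identity mapping, which tends to stabilize early training.
final_bn = tf.keras.layers.BatchNormalization(gamma_initializer="zeros")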


Thank you for your quick reply. I am well aware that tutorials are not meant to be benchmark implementations, but I was surprised that it did not train usable representations at all.

By true random cropping I mean a random crop-and-resize as described in the paper: it crops a random fraction of the original image (the range is given as hyperparameters) and then resizes the crop back to the original size. tf.image.random_crop, as used in the tutorial, crops an image at a random location to the given size; since we’re cropping a 32x32 image to 32x32, it effectively does nothing. I thought the lack of cropping was what kept the tutorial code from working, but fixing this did not improve the performance.

I did not match the backbone from the paper or its hyperparameters; I used what you used in the tutorial, as it is very close. I’ll try changing the backbone and maybe it’ll help, but I’d be very surprised if a slightly different backbone changed the linear-transfer score from 35% to the 90% reported in the paper. It would also be pretty damning for SimSiam if it turned out the backbone was the primary contributor to benchmark success.

Thanks for any help. I’m just trying to get a clean, simple implementation that performs roughly like the paper’s, and hopefully to find a few tweaks that would let a person get a usable trained model from the tutorial.

I wanted to reiterate one of the statements regarding these examples:

NOTE THAT THIS COMMAND WILL ERROR OUT IF ANY CELLS TAKES TOO LONG TO EXECUTE. In that case, make your code lighter/faster. Remember that examples are meant to demonstrate workflows, not train state-of-the-art models. They should stay very lightweight.

So, in the case of this example, we get to 35%, which is better than random chance (10%). But I do understand your point, and I will try my best to offer suggestions from my experience (that’s the whole purpose of having a discussion forum, isn’t it? :)).

Random-resized crops. You are right; I should have included a note about this in the tutorial itself, but I was a bit worried about the line count. Refer to this implementation and see if it works for you. It’s a bit different from PyTorch’s RandomResizedCrop(), but it has worked well in the other experiments I have performed in this area.
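In case the link isn’t handy, the idea is roughly the following. This is a minimal sketch of my own (not necessarily what the linked code does), assuming square images; the scale range is a common choice:

import tensorflow as tf

def random_resized_crop(image, size=32, scale=(0.2, 1.0)):
    # Sample a target crop area uniformly within `scale` of the original
    # area, crop at a random location, then resize back to (size, size).
    side = tf.cast(tf.shape(image)[0], tf.float32)
    area_frac = tf.random.uniform((), scale[0], scale[1])
    crop_side = tf.cast(side * tf.sqrt(area_frac), tf.int32)
    image = tf.image.random_crop(image, [crop_side, crop_side, 3])
    return tf.image.resize(image, (size, size))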

I agree here. But the relationship between your dataset’s complexity and your architecture’s complexity matters. In some methods, like SimCLR, this effect is very pronounced; in methods like Barlow Twins, it is less so. You can find an implementation of Barlow Twins here, and within just 100 epochs of pre-training I was able to get to 62.61% linear-evaluation accuracy on CIFAR-10.
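If it helps, the heart of Barlow Twins is a redundancy-reduction loss on the cross-correlation matrix of the two views’ embeddings. A minimal sketch (variable names are mine; lambd is the paper’s 5e-3):

import tensorflow as tf

def barlow_twins_loss(z_a, z_b, lambd=5e-3):
    # Standardize each embedding dimension across the batch.
    z_a = (z_a - tf.reduce_mean(z_a, axis=0)) / (tf.math.reduce_std(z_a, axis=0) + 1e-8)
    z_b = (z_b - tf.reduce_mean(z_b, axis=0)) / (tf.math.reduce_std(z_b, axis=0) + 1e-8)
    n = tf.cast(tf.shape(z_a)[0], tf.float32)
    # Cross-correlation matrix between the two views (d x d).
    c = tf.matmul(z_a, z_b, transpose_a=True) / n
    # Pull the diagonal toward 1 (invariance) and push the off-diagonal
    # toward 0 (redundancy reduction).
    diag = tf.linalg.diag_part(c)
    on_diag = tf.reduce_sum(tf.square(diag - 1.0))
    off_diag = tf.reduce_sum(tf.square(c)) - tf.reduce_sum(tf.square(diag))
    return on_diag + lambd * off_diag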

I appreciate this so much. The field is relatively new, so as contributors we wanted to get the workflow out as soon as possible without missing the important bits. So, I am here to offer whatever suggestions I can to improve your results.

What you are currently running into (poor performance) is likely due to a phenomenon called representation collapse. This is when your backbone produces (nearly) the same output for every input image, a failure mode that self-supervised methods for vision easily fall into. More on this later.
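In the meantime, an easy way to detect collapse is the diagnostic used in the SimSiam paper: the per-dimension standard deviation of the l2-normalized outputs. A quick sketch (encoder and images are placeholders for your backbone and a batch of inputs):

import tensorflow as tf

def output_std(encoder, images):
    # Per-dimension std of l2-normalized embeddings, averaged over dims.
    # Near 0 means the representations have collapsed; a healthy run
    # hovers around 1 / sqrt(d), where d is the embedding dimension.
    z = encoder(images, training=False)
    z = tf.math.l2_normalize(z, axis=1)
    return tf.reduce_mean(tf.math.reduce_std(z, axis=0))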

One thing to note: for linear evaluation, a separate configuration of hyperparameters and a separate augmentation pipeline are often used. For SimSiam’s case, they are specified under the A. Implementation Details section. Unfortunately, the authors did not state these for CIFAR-10, so this leaves us fair ground for further experimentation.
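For reference, a bare-bones linear-evaluation setup in Keras looks roughly like this. Here backbone, train_ds, and test_ds are placeholders, and the SGD values are common choices rather than the paper’s (unstated) CIFAR-10 settings:

import tensorflow as tf

# `backbone` is the frozen, pre-trained encoder; train_ds / test_ds yield
# (image, label) batches of CIFAR-10.
backbone.trainable = False
inputs = tf.keras.Input(shape=(32, 32, 3))
features = backbone(inputs, training=False)
outputs = tf.keras.layers.Dense(10)(features)
linear_model = tf.keras.Model(inputs, outputs)

linear_model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.03, momentum=0.9),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
linear_model.fit(train_ds, validation_data=test_ds, epochs=90)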

At this point, you’d be totally correct to think that self-supervision (for computer vision) is often sensitive to hyperparameters. But methods like Barlow Twins and VICReg help to reduce that sensitivity to some extent. They also gracefully mitigate the problem of representation collapse (as does DINO).

I hope this helps.


I haven’t tested this personally, but you can check:

Thank you for this reply. I’ll have time to experiment with this more this week. For context, I am running the code outside of a notebook on my own machine, so I can train 100 epochs pretty quickly. I train the model, save it, and then load it into an eval script that uses the code from the tutorial for linear evaluation but lets me use different hyperparameters. I have also written my own random crop-and-resize using TensorFlow’s tf.image.sample_distorted_bounding_box:

def random_crop_and_resize(image, size, area_range):
    # Sample a random bounding box whose area falls within `area_range`
    # (e.g. (0.2, 1.0)) of the full image, then resize the crop back to
    # (size, size), as in the paper's random resized crop.
    begin, shape, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=[[[0.0, 0.0, 1.0, 1.0]]],  # the whole image is valid
        area_range=area_range,
    )
    image = tf.image.crop_to_bounding_box(
        image, begin[0], begin[1], shape[0], shape[1]
    )
    image = tf.image.resize(image, (size, size))
    return image
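As a quick sanity check on a dummy image:

image = tf.random.uniform((32, 32, 3))
cropped = random_crop_and_resize(image, size=32, area_range=(0.2, 1.0))
print(cropped.shape)  # (32, 32, 3)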

Thanks for the Barlow Twins recommendation. The paper seemed neat, but I didn’t realize it is empirically less prone to representation collapse. Currently I need SimSiam for some work I want to do extending it, so having my own working version that I understand is important. I also believe this can shed some light on how much the core algorithm matters versus the training tricks.

I hadn’t realized from any of the papers that augmentations were used during the linear-evaluation training step. I will experiment with this a bit more. I’ll also inspect the output representations from training to see whether they’ve collapsed to a fixed output. I’ll post here if the tuning is successful or if I get stuck again.

Thank you.

Thanks. I’m trying to build my own example right now with just the minimal important parts, which is why I started with the tutorial. I’ll inspect this soon to see if there are any important parts the Keras code is missing.

Sure. It will only make the knowledge exchange better.

I would also manually inspect the outputs produced by random_crop_and_resize() (in case you haven’t already).
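Something as simple as the following works for that, assuming images is a batch of CIFAR-10 images with values in [0, 255]:

import matplotlib.pyplot as plt
import tensorflow as tf

fig, axes = plt.subplots(2, 4, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
    crop = random_crop_and_resize(images[i], size=32, area_range=(0.2, 1.0))
    ax.imshow(tf.cast(crop, tf.uint8).numpy())
    ax.axis("off")
plt.show()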

Judging from most of the papers, the training tricks are important. But I see your point.

Hi, did you manage to figure it out? I am facing the same issue: the representations are collapsing.

@Arnab_Mondal, check the suggestions I shared in my comments; they are likely to help with that. For SimSiam, the following factors collectively contribute to preventing representation collapse:

  • stop_gradient()
  • Augmentations
  • The specific use of Batch Norm
  • The use of an autoencoder-like predictor network

All of these are reflected in the initial example I posted, barring the augmentations and the exact number of dense units in the predictor.
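To make the first item concrete, here is the negative cosine similarity loss with the stop-gradient, essentially as it appears in the paper and the example (p is the predictor output of one view, z the projection of the other):

import tensorflow as tf

def negative_cosine_similarity(p, z):
    # The stop_gradient on z is what prevents the trivial collapsed solution.
    z = tf.stop_gradient(z)
    p = tf.math.l2_normalize(p, axis=1)
    z = tf.math.l2_normalize(z, axis=1)
    return -tf.reduce_mean(tf.reduce_sum(p * z, axis=1))

# Symmetrized over the two augmented views:
# loss = 0.5 * negative_cosine_similarity(p1, z2) + 0.5 * negative_cosine_similarity(p2, z1)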

As I mentioned previously in my comments, I am happy to help here with whatever suggestions I can provide from my (admittedly limited) experience working with SSL. So please keep this thread posted with anything you find.