Train a Vision Transformer on small datasets

ViTs are data hungry: pretraining a ViT on a huge dataset like JFT-300M and then fine-tuning it on medium-sized datasets (like ImageNet) has so far been the only way to beat state-of-the-art convolutional neural network models.

The self-attention layers of a ViT lack locality inductive bias (the notion that image pixels are locally correlated and that their correlation maps are translation-invariant), which is why ViTs need more data. CNNs, on the other hand, look at images through spatially local sliding windows, which helps them get better results on smaller datasets.

In my latest Keras example, I minimally implement the paper Vision Transformer for Small-Size Datasets, in which the authors set out to tackle the lack of locality inductive bias in ViTs with two novel ideas:

  • Shifted Patch Tokenization (SPT): A tokenization scheme that gives the transformer a greater receptive field per token (see the first sketch after this list).

  • Locality Self-Attention (LSA): A tweaked version of the multi-head self-attention mechanism: applying a diagonal mask and a learnable temperature to the regular self-attention gives us LSA. Inheriting from tf.keras.layers.MultiHeadAttention and tweaking its API was my greatest win while implementing LSA (see the second sketch below).
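To make the SPT idea concrete, here is a minimal sketch of how shifted patch tokenization could look in TensorFlow/Keras. This is not the code from the example or the paper: the layer name, the `diagonal_shift` helper, and the hyperparameters (image size 32, patch size 4, projection dimension 64) are assumptions for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers


def diagonal_shift(images, dy, dx, image_size):
    # Shift the image content by (dy, dx) pixels, zero-padding the border
    # instead of wrapping around.
    pad = tf.pad(images, [[0, 0], [abs(dy), abs(dy)], [abs(dx), abs(dx)], [0, 0]])
    start_y, start_x = abs(dy) - dy, abs(dx) - dx
    return pad[:, start_y:start_y + image_size, start_x:start_x + image_size, :]


class ShiftedPatchTokenization(layers.Layer):
    """Concatenate four diagonally shifted copies of the image with the
    original, extract non-overlapping patches, then LayerNorm + project."""

    def __init__(self, image_size=32, patch_size=4, projection_dim=64, **kwargs):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.half_patch = patch_size // 2
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.projection = layers.Dense(projection_dim)

    def call(self, images):
        s = self.half_patch
        # Original image plus four half-patch diagonal shifts, stacked on channels.
        shifts = [(0, 0), (-s, -s), (-s, s), (s, -s), (s, s)]
        x = tf.concat(
            [diagonal_shift(images, dy, dx, self.image_size) for dy, dx in shifts],
            axis=-1,
        )
        # Split into non-overlapping patches and flatten each into a token.
        patches = tf.image.extract_patches(
            images=x,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        num_patches = (self.image_size // self.patch_size) ** 2
        tokens = tf.reshape(patches, (-1, num_patches, patches.shape[-1]))
        return self.projection(self.norm(tokens))


# Example: a CIFAR-sized batch -> (batch, 64, 64) patch-token embeddings.
tokens = ShiftedPatchTokenization()(tf.random.normal((8, 32, 32, 3)))
```

Because every token now also sees half-patch-shifted neighbours, its effective receptive field covers a larger spatial area than a plain patch would.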
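And here is a sketch of the LSA side, following the subclassing route mentioned above: the fixed 1/sqrt(d_k) scaling is replaced with a learnable temperature, and a diagonal mask stops each token from attending to itself. Note that `_compute_attention` and the helpers it calls (`_dot_product_equation`, `_masked_softmax`, `_dropout_layer`, `_combine_equation`) are private internals of tf.keras.layers.MultiHeadAttention in the TF 2.x line and may differ in other Keras versions; the mask construction and the dimensions below are illustrative assumptions.

```python
import math
import tensorflow as tf


class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Learnable temperature, initialized to sqrt(key_dim) so the layer
        # starts out identical to standard scaled dot-product attention.
        self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)

    def _compute_attention(self, query, key, value, attention_mask=None, training=None):
        # Divide by the learnable temperature instead of the fixed sqrt(d_k).
        query = tf.multiply(query, 1.0 / self.tau)
        attention_scores = tf.einsum(self._dot_product_equation, key, query)
        # The diagonal mask is applied here, inside the masked softmax.
        attention_scores = self._masked_softmax(attention_scores, attention_mask)
        attention_scores_dropout = self._dropout_layer(attention_scores, training=training)
        attention_output = tf.einsum(self._combine_equation, attention_scores_dropout, value)
        return attention_output, attention_scores


# Diagonal attention mask: zeros on the diagonal so a token never attends to itself.
num_patches = 64  # assumed; must match the tokenizer output
diag_attn_mask = tf.cast([1 - tf.eye(num_patches)], dtype=tf.int8)

attn = MultiHeadAttentionLSA(num_heads=4, key_dim=16, dropout=0.1)
# tokens: (batch, num_patches, dim), e.g. the SPT output from the sketch above.
tokens = tf.random.normal((8, num_patches, 64))
out = attn(tokens, tokens, attention_mask=diag_attn_mask)
```

Masking out the diagonal and learning the temperature both sharpen the attention distribution, which pushes each token to attend to other (local) tokens instead of collapsing onto itself.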

Tutorial: Train a Vision Transformer on small datasets


Amazing 🙂 👍 Thank you
