[Model] TF2.0 Implementation of "GCViT: Global Context Vision Transformer"

Hello everyone,
As you may already know, NVIDIA recently published (20 June 2022) its paper GCViT: Global Context Vision Transformer, which outperforms ConvNeXt and SwinTransformer.
I’ve implemented this model in TensorFlow 2.0 and released it as an open-source library, gcvit-tf. I’ve also made a notebook explaining it. I hope this helps the community. :slight_smile:

I’m also planning to publish it to TFHub; it would be really helpful if I could get some directions regarding this…

Here are the links to the project:

Features of gcvit-tf:

  • This library loads ImageNet weights from the official repo.
  • Also, it has timm-like features such as forward_features, forward_head, and reset_classifier, which might come in handy.
  • It can be used on both GPU and TPU (see the sketch after this list).
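
For TPU, the standard TF distribution-strategy setup works. Here is a minimal sketch, assuming a TPU runtime is attached (the resolver arguments may differ per environment):

import tensorflow as tf
from gcvit import GCViTTiny

# Connect to the TPU and create a distribution strategy
# (no-arg auto-detection usually works on Colab/Kaggle TPU runtimes).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = GCViTTiny(pretrain=True)  # variables are created on the TPU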

Supported Models

The official codebase had an issue, which was fixed recently (27 July 2022). Here are the results of the ported weights on the ImageNetV2-Test data:

Model          Acc@1 (%)   Acc@5 (%)   #Params
GCViT-XXTiny   63          85          12M
GCViT-XTiny    66          87          20M
GCViT-Tiny     69          89          28M
GCViT-Small    69          89          51M
GCViT-Base     71          90          90M
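
For reference, here is a rough sketch of how such an evaluation could be reproduced with TensorFlow Datasets; the dataset name, split, batch size, and preprocessing below are illustrative assumptions, not the exact script behind the table:

import tensorflow as tf
import tensorflow_datasets as tfds
from gcvit import GCViTTiny

def preprocess(example):
    # Same torch-style normalization used for inference below.
    img = tf.image.resize(example['image'], (224, 224))
    img = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    return img, example['label']

ds = tfds.load('imagenet_v2', split='test').map(preprocess).batch(64)

model = GCViTTiny(pretrain=True)
model.compile(loss='sparse_categorical_crossentropy',
              metrics=['accuracy',  # Acc@1
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(5, name='acc@5')])
model.evaluate(ds)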

Usage

Install Library

pip install gcvit

Load a model with the following code:

from gcvit import GCViTTiny
model = GCViTTiny(pretrain=True)

Simple code to check the model’s prediction:

import tensorflow as tf
from skimage.data import chelsea

img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])

Prediction:

[('n02124075', 'Egyptian_cat', 0.9194835),
('n02123045', 'tabby', 0.009686623), 
('n02123159', 'tiger_cat', 0.0061576385),
('n02127052', 'lynx', 0.0011503297), 
('n02883205', 'bow_tie', 0.00042479983)]

For feature extraction:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)
feature = model(img)
print(feature.shape)

Feature:

(None, 512)
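
The headless model drops straight into a Keras pipeline. Here is a hedged sketch of transfer learning on a custom dataset; the 5-class head, optimizer, and learning rate are placeholders, not recommendations:

import tensorflow as tf
from gcvit import GCViTTiny

backbone = GCViTTiny(pretrain=True)
backbone.reset_classifier(num_classes=0, head_act=None)  # outputs 512-d features

inputs = tf.keras.Input((224, 224, 3))
features = backbone(inputs)                                         # (None, 512)
outputs = tf.keras.layers.Dense(5, activation='softmax')(features)  # hypothetical 5-class task

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# model.fit(train_ds, validation_data=val_ds, epochs=5)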

For the feature map:

model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)

Feature map:

(None, 7, 7, 512)
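
forward_head covers the remaining step. Assuming it follows timm’s convention and consumes the output of forward_features (the exact signature may differ), the two calls compose back into the full model:

import tensorflow as tf
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)
x = tf.random.normal((1, 224, 224, 3))  # dummy batch
fmap = model.forward_features(x)        # (1, 7, 7, 512) feature map
pred = model.forward_head(fmap)         # (1, 1000) class predictions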

Note: The official repo had some issues that resulted in a performance drop. It was updated recently (27 July 2022), but one issue still persists, so the ImageNet weights may be updated in the future.


Hi Awsaf, does your implementation of GC ViT include semantic image segmentation? Is there a training example notebook that you can share? Thank you!