Recommended way to interpolate positional embeddings in ViT in TensorFlow

The usual practice for using a Vision Transformer model on images with a different resolution than the training one is as follows. Say we're inferring on 480x480 images as opposed to the 224x224 training resolution.

The learned positional embeddings (or sin/cosine embeddings, or relative positional biases) are interpolated to match the target resolution. While this is trivial to implement in general, in TensorFlow it turns out to be non-trivial, particularly when one tries to serialize the model with model.save().

The following notebook implements a Vision Transformer model including all the vanilla recipes introduced in [1], such as learned positional embeddings, the use of a class token, etc. It also has a function to interpolate the positional embeddings when the input resolution is different from the training one.
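For reference, a minimal sketch of how such an interpolation can be done (assuming a learned embedding of shape `(1, 1 + num_patches, dim)` with the class token in the first slot; the function name and signature are illustrative, not the notebook's):

```python
import tensorflow as tf

def interpolate_pos_embeddings(pos_embed, num_patches_new):
    """Resize learned positional embeddings (class token + square patch grid)
    to a new number of patches.

    pos_embed: (1, 1 + num_patches_old, dim); num_patches_new must be a square.
    """
    cls_embed = pos_embed[:, :1, :]    # class-token slot is kept as-is
    patch_embed = pos_embed[:, 1:, :]  # (1, num_patches_old, dim)

    grid_old = int(patch_embed.shape[1] ** 0.5)
    grid_new = int(num_patches_new ** 0.5)

    # Reshape to a 2D grid, resize bicubically, and flatten back.
    patch_embed = tf.reshape(patch_embed, (1, grid_old, grid_old, -1))
    patch_embed = tf.image.resize(patch_embed, (grid_new, grid_new), method="bicubic")
    patch_embed = tf.reshape(patch_embed, (1, grid_new * grid_new, -1))

    return tf.concat([cls_embed, patch_embed], axis=1)
```

With 16x16 patches, 224x224 training images give a 14x14 grid (196 patches), and 480x480 images a 30x30 grid (900 patches).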

I’ve even tried decorating the call() method with tf.function (see below), but it doesn’t help either.

@tf.function(
    input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)]
)
def call(self, inputs):
    ...

Any workarounds?

Cc @ariG23498

Have you already seen:

Model saving is the issue here (and not how to implement the interpolation part in TensorFlow).

Ok, you did not explain in the description that the problem was related to saving/serialization (it was only in the Colab).
Also, if you look, their implementation saves the model. I suppose without problems, but I have not tested it personally.

I’ve not looked at your repo implementation in detail, but I think you are in the same case as:

That’s a warning and I am aware of it.

Which one are you referring to?

Yup, sorry. Just edited the post description.

The same repo. There are some save-and-interpolate tests to check:

What is your problem exactly?

I don’t see much in the Colab other than

WARNING:tensorflow:No training configuration found in save file, so the model was not compiled. Compile it manually.

As your model was not compiled, why don’t you load it with:
vit_dino_base_loaded = tf.keras.models.load_model("vit_dino_base", compile=False)

If you run the Colab with the decorator enabled, the model instantiation itself won’t run.

Without it, things work okay, but here’s the problem: after the model is instantiated and called on some inputs, it fixes its input shapes and is therefore unable to operate on inputs of different resolutions.

Sorry for not making it clear.

The input shape needs to be defined for build/save, so in your case you cannot change the input shape after loading.

You need to use a sort of transfer weights approach:
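One way to sketch that weight-transfer idea: build a second model at the target resolution and copy the weights across, resizing only the positional embedding. The `PatchEmbed` layer, `build_model` helper, and weight names below are toy stand-ins for illustration, not the actual repo code:

```python
import tensorflow as tf

class PatchEmbed(tf.keras.layers.Layer):
    """Toy patch projection plus a learned positional embedding (no class token)."""
    def __init__(self, num_patches, dim, **kwargs):
        super().__init__(**kwargs)
        self.proj = tf.keras.layers.Dense(dim)
        self.pos_embedding = self.add_weight(
            name="pos_embedding", shape=(1, num_patches, dim),
            initializer="random_normal",
        )

    def call(self, x):
        return self.proj(x) + self.pos_embedding

def build_model(grid, dim=32, patch_dim=48):
    inputs = tf.keras.Input((grid * grid, patch_dim))
    outputs = PatchEmbed(grid * grid, dim)(inputs)
    return tf.keras.Model(inputs, outputs)

source = build_model(grid=14)  # trained at 224 / 16 = 14x14 patches
target = build_model(grid=30)  # inference at 480 / 16 = 30x30 patches

# Both models have identical architectures, so weights line up pairwise.
for src, dst in zip(source.weights, target.weights):
    if "pos_embedding" in dst.name:
        # Resize the positional grid with bicubic interpolation.
        grid_old = int(src.shape[1] ** 0.5)
        grid_new = int(dst.shape[1] ** 0.5)
        resized = tf.image.resize(
            tf.reshape(src, (1, grid_old, grid_old, -1)),
            (grid_new, grid_new), method="bicubic",
        )
        dst.assign(tf.reshape(resized, dst.shape))
    else:
        dst.assign(src)
```

In the real setting, `source` would be the model loaded with `compile=False` and `target` a freshly built ViT at the inference resolution.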

See also:

Thanks for the pointers. Appreciate your help.

Cc: @ariG23498