Vanishing gradients while fine-tuning EfficientNetV2-XL

I have an aerial imagery dataset (around 250k images for training) and am trying to fine-tune the EfficientNetV2-XL architecture pretrained on the ImageNet-21k dataset. However, as the architecture is not directly available in Keras, I am using the pretrained backbone from TensorFlow Hub. Below is the implementation of my model.

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers

inputs = layers.Input(shape=(512, 512, 3))
# pretrained EfficientNetV2-XL (ImageNet-21k) feature extractor from TF Hub
backbone = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2",
    trainable=True)
x = backbone(inputs)

# multi-task learning: 3 classification heads sharing the same backbone
output1 = layers.Dense(3, activation='softmax')(x)
output2 = layers.Dense(6, activation='softmax')(x)
output3 = layers.Dense(5, activation='softmax')(x)

model = tf.keras.Model(inputs, [output1, output2, output3])

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(0.001, 10000)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipvalue=2)

The sparse_categorical_crossentropy loss is used for each of the three outputs to optimize the weights.
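
For reference, the compile step looks roughly like this (the equal loss weights and the accuracy metric are assumptions, not part of my original script):

# minimal sketch of the compile step, assuming equal weighting of the three heads
model.compile(
    optimizer=optimizer,
    loss=['sparse_categorical_crossentropy'] * 3,  # one loss per output head
    loss_weights=[1.0, 1.0, 1.0],                  # assumed equal weighting; tune per task
    metrics=['accuracy'])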

I am using 220k images to train the model with a batch size of 8. At the end of the first epoch I see that the weights become NaN, and slowly the loss becomes NaN as well.
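
As a quick diagnostic, stopping training at the first NaN loss makes it easier to see exactly when this happens; a minimal sketch (train_ds and the epoch count are placeholders for my actual setup):

# sketch: abort training as soon as the loss becomes NaN
# `train_ds` and `epochs` are placeholders for the actual training setup
model.fit(
    train_ds,
    epochs=10,
    callbacks=[tf.keras.callbacks.TerminateOnNaN()])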

I believe that the gradients are vanishing, and any help to prevent this issue is highly appreciated.

Hi @Dunna_SuryaNarayana,

As your dataset is quite large, it is possible that it contains some corrupted or invalid images. It is important to check for these before you start training; a minimal check is sketched just below. After that, please work through the steps that follow.
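
For example, a minimal sketch for flagging files that TensorFlow cannot decode (assuming the images are stored as JPEG/PNG files on disk and image_paths is a list of their paths):

import tensorflow as tf

# sketch: flag image files that TensorFlow cannot decode
# `image_paths` is an assumed list of paths to the training images
bad_files = []
for path in image_paths:
    try:
        data = tf.io.read_file(path)
        tf.io.decode_image(data)   # raises InvalidArgumentError on corrupt files
    except tf.errors.InvalidArgumentError:
        bad_files.append(path)
print(f"Found {len(bad_files)} undecodable images")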

  1. Data Preprocessing: Ensure that your input data is properly preprocessed. For aerial imagery, consider normalizing the pixel values to the range 0 to 1 (see the sketch after this list).
  2. Learning Rate: You are using a cosine-decay learning rate schedule, which is generally effective. Can you try reducing the initial learning rate further, for example to 0.0001, and see if that helps stabilize the training?
  3. Gradient Clipping: Set clipvalue=1.0 and see if it helps stabilize the training.
  4. Batch Size: Try increasing the batch size to a larger value, such as 16 or 32, to stabilize the training process.
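
Putting points 1 to 3 together (the batch size in point 4 is set wherever your input pipeline is built), a minimal sketch of the changed pieces; the Rescaling placement and the decay_steps value are assumptions to adapt to your pipeline:

# sketch of the suggested changes: rescale inputs to [0, 1], lower the
# initial learning rate, and clip gradient values at 1.0
inputs = layers.Input(shape=(512, 512, 3))
x = layers.Rescaling(1.0 / 255)(inputs)   # normalize pixel values to [0, 1]
x = backbone(x)                           # pretrained TF Hub feature extractor

lr_schedule = tf.keras.optimizers.schedules.CosineDecay(0.0001, 10000)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipvalue=1.0)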

Please try all the above steps and let us know if they solve the NaN problem.

Thanks.


I was trying to fine-tune GPT-2 weights loaded into a GPT model I created from scratch, and I got NaN values after one or two epochs. After removing the clipping the problem disappears, which is good :slight_smile: but it left me with a headache :( . Why could clipping be bad, especially during fine-tuning?

I ask because I have used the same architecture to train a small GPT model and nothing bad apparently happened.

Any intuition or evidence for why this happened?