Transfer learning/fine-tuning - importance of setting the base model to "inference mode"?

I have a question regarding transfer learning and fine-tuning. The TensorFlow tutorial "Transfer learning and fine-tuning" explains the importance of setting the base model to "inference mode" for transfer learning and fine-tuning, and explains that this is a different concept from freezing/unfreezing layers (setting "trainable" to False/True).
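To make the distinction concrete, here is a minimal sketch with a single `BatchNormalization` layer, the layer type the tutorial's inference-mode advice is about. The data and shapes are arbitrary assumptions. `training` is passed with each call and picks the forward-pass behavior, while `trainable` is an attribute that decides which weights an optimizer is allowed to update.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(64, 4)).astype("float32")
bn = keras.layers.BatchNormalization()

# `training` (per call): selects how the forward pass behaves.
y_train = np.asarray(bn(x, training=True))   # normalizes with this batch's mean/variance
y_infer = np.asarray(bn(x, training=False))  # normalizes with the moving mean/variance

# `trainable` (attribute): selects which weights are exposed to the optimizer.
counts_before = (len(bn.trainable_weights), len(bn.non_trainable_weights))
bn.trainable = False
counts_after = (len(bn.trainable_weights), len(bn.non_trainable_weights))

print(counts_before)  # (2, 2): gamma/beta trainable, moving mean/variance not
print(counts_after)   # (0, 4): nothing left for the optimizer to update
print(y_train.mean(), y_infer.mean())  # near 0 vs. near the raw data mean
```

The two outputs differ because the moving statistics have barely moved after a single training-mode call, whereas the training-mode output is normalized with the batch's own statistics.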

When checking other online sources on transfer learning/fine-tuning, I saw that none of them sets the base model to "inference mode". They all use the base model in "training mode".

Therefore my question is: why does no one set the base model to "inference mode" as suggested in the TensorFlow tutorial?
Are these just examples of using transfer learning/fine-tuning in the wrong way? Or are the authors of the other tutorials aware of this, and setting the base model to "training mode" has some other advantages that are mentioned neither in their tutorials nor in the TensorFlow tutorial?

My understanding is that you need to set the base model to inference mode in order to keep the weights as they are and not wreck all the pre-trained parameters while you train the higher layers. The lower layers have been trained to recognise basic features that are most probably the same for the problem you are trying to solve.
After fine-tuning the higher-level layers that you have put on top of the base model, you might also make the top layers of the base model trainable and fine-tune them at a low learning rate.
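Unfreezing only the top layers of a base model could be sketched like this. It is a minimal example under assumptions: the six-block `Dense` base model stands in for a real pre-trained network, and `num_top_layers` and the learning rate are hypothetical choices, not values from the tutorial.

```python
from tensorflow import keras

# Hypothetical stand-in for a pre-trained base model with six blocks.
base_model = keras.Sequential(
    [keras.Input(shape=(8,))]
    + [keras.layers.Dense(8, activation="relu", name=f"block{i}") for i in range(6)]
)
inputs = keras.Input(shape=(8,))
outputs = keras.layers.Dense(1)(base_model(inputs))  # new head on top
model = keras.Model(inputs, outputs)

# Unfreeze the base model, then re-freeze all but its top `num_top_layers` layers.
num_top_layers = 2  # assumption: how many of the highest layers to fine-tune
base_model.trainable = True
for layer in base_model.layers[:-num_top_layers]:
    layer.trainable = False

# Recompile after changing `trainable`, using a low learning rate for fine-tuning.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5), loss="mse")

print([layer.name for layer in base_model.layers if layer.trainable])  # ['block4', 'block5']
print(len(model.trainable_weights))  # 6: kernel and bias of block4, block5 and the head
```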

There are different flavors of transfer learning:

  • You can use a pre-trained model to extract deep representations from your data (say images) and use those to train a simple classifier (for example).

  • There’s another setting where you completely unfreeze the pre-trained model parameters, attach a custom head to that pre-trained model, and then begin training. This is referred to as fine-tuning. This is done to adjust the pre-trained model parameters to the given task. Let’s say your model was pre-trained on Dataset A with image classification as the (pre-)training objective and now you’d like to fine-tune the model on Dataset B but with the same training objective of image classification. This is still transfer learning since Dataset B is assumed to be different.

  • You can also combine the two approaches above, i.e.:

    a. First train a simple classification model with the pre-trained model parameters kept frozen (non-trainable).
    b. Then unfreeze a few layers of the pre-trained model, run training again, and repeat the process until the model generalizes well.

All of these are available approaches and there are no set rules as to what approach is bound to work best for a given transfer learning task. So, experimentation is your best friend here.
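The combined approach (a, then b) might look roughly like this in Keras. This is a minimal sketch under stated assumptions: the tiny base model, the random toy data and both learning rates are hypothetical, and for brevity step b unfreezes the whole base model rather than just a few layers.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)
x = rng.normal(size=(128, 8)).astype("float32")
y = (x.sum(axis=1, keepdims=True) > 0).astype("float32")

# Hypothetical stand-in for a pre-trained base model (randomly initialized here).
base_model = keras.Sequential(
    [
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(16, activation="relu"),
    ]
)
inputs = keras.Input(shape=(8,))
outputs = keras.layers.Dense(1, activation="sigmoid")(base_model(inputs))
model = keras.Model(inputs, outputs)
base_weights_start = [w.copy() for w in base_model.get_weights()]

# a. Freeze the base model and train only the new head.
base_model.trainable = False
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3), loss="binary_crossentropy")
model.fit(x, y, epochs=3, verbose=0)
base_weights_after_a = base_model.get_weights()

# b. Unfreeze the base model, recompile so the change takes effect,
#    and train again with a much lower learning rate.
base_model.trainable = True
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5), loss="binary_crossentropy")
model.fit(x, y, epochs=1, verbose=0)
base_weights_after_b = base_model.get_weights()
```

After step a the base model's weights are unchanged; after step b they have moved slightly.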

Thanks for the answers.

I will probably start with the approach described in the tutorial:

  1. Transfer learning
      • Setting the base model to inference mode
      • Freezing all layers
  2. Fine-tuning the model trained in the previous step
      • Setting the base model to inference mode
      • Unfreezing layers
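Step 2 of this plan (unfreeze, but keep the base model in inference mode) can be sketched with an explicit training loop. The toy base model, data and learning rate are assumptions. The point is that `base_model.trainable = True` lets the optimizer update the `Dense` kernel, while calling the base model with `training=False` leaves the `BatchNormalization` moving statistics untouched. A custom `tf.GradientTape` loop is used here only to keep the `training=False` call explicit; the tutorial itself builds a functional model and uses `fit`.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical small "pre-trained" base model containing BatchNormalization.
base_model = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(4), keras.layers.BatchNormalization()]
)
head = keras.layers.Dense(1)  # new layer on top of the base model
optimizer = keras.optimizers.Adam(learning_rate=1e-4)  # low rate for fine-tuning

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, size=(32, 4)).astype("float32")
targets = rng.normal(size=(32, 1)).astype("float32")

# Step 2: unfreeze the base model, but keep it in inference mode.
base_model.trainable = True
bn = base_model.layers[1]
moving_mean_before = np.array(bn.moving_mean)
kernel_before = np.array(base_model.layers[0].kernel)

for _ in range(3):
    with tf.GradientTape() as tape:
        features = base_model(data, training=False)  # BN normalizes with moving stats
        loss = tf.reduce_mean(tf.square(head(features) - targets))
    variables = base_model.trainable_variables + head.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

print(np.array_equal(moving_mean_before, np.array(bn.moving_mean)))  # True
print(np.array_equal(kernel_before, np.array(base_model.layers[0].kernel)))  # False
```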

Are there indicators of when it would make sense to unfreeze layers right at the beginning of training, to adapt an existing model to a new task?

A Model is also a Layer. This code:

```python
model.trainable = False
```

is implemented by the Model class as (roughly) this:

```python
for layer in model.layers:
    layer.trainable = False
```

This is the only Keras API that implements "inference mode" and "freezing and unfreezing" layers. There is no separate "inference mode" API in the Model class.
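A quick sketch to check the propagation described above, on a tiny, hypothetical two-layer model (no pre-trained weights involved):

```python
from tensorflow import keras

model = keras.Sequential(
    [keras.Input(shape=(3,)), keras.layers.Dense(4), keras.layers.Dense(1)]
)
model.trainable = False  # set once on the model...

# ...and every sub-layer now reports trainable=False as well.
trainable_flags = [layer.trainable for layer in model.layers]
weight_counts = (len(model.trainable_weights), len(model.non_trainable_weights))
print(trainable_flags)  # [False, False]
print(weight_counts)    # (0, 4): both kernels and biases are frozen
```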

Perhaps you should check the "transfer learning guide" (Transfer learning with a pretrained ConvNet | TensorFlow Core) then, which is where the following code comes from:


```python
from tensorflow import keras

# Assumed augmentation pipeline (not part of the excerpt); the guide defines a similar one.
data_augmentation = keras.Sequential(
    [keras.layers.RandomFlip("horizontal"), keras.layers.RandomRotation(0.1)]
)

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights require that inputs be scaled
# from [0, 255] to a range of [-1, +1]; the Rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
```
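The Rescaling arithmetic in the comment can be checked directly: with scale = 1/127.5 and offset = -1, a pixel value p maps to p / 127.5 - 1, so 0 maps to -1, 127.5 to 0 and 255 to 1. The sketch below also compares against `keras.applications.xception.preprocess_input`, which, as far as I know, applies the same [-1, 1] scaling for Xception; the sample pixel values are arbitrary.

```python
import numpy as np
from tensorflow import keras

pixels = np.array([[0.0, 127.5, 255.0]], dtype="float32")
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
scaled = np.asarray(scale_layer(pixels))  # maps 0 -> -1, 127.5 -> 0, 255 -> 1

# Xception's preprocessing helper; it may modify a NumPy array in place, hence the copy.
reference = keras.applications.xception.preprocess_input(pixels.copy())
print(np.allclose(scaled, reference, atol=1e-6))
```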

Adding the following line when building the new model on top of the base model sets the base model to inference mode:


```python
x = base_model(x, training=False)
```

Ah, you're right! Wait, does it set the base_model itself to inference mode here, or does it create a separate "inference mode" context only for the new tensor created by this call?

I am sorry, but I do not think I understand the question.

The documentation says that passing training=False sets this value for the __call__ method of each layer in the base model:

layer.__call__() (which controls whether the layer should run its forward pass in inference mode or training mode)
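A minimal check with direct calls, using a hypothetical base model that consists of a single `BatchNormalization` layer: an inference-mode call leaves the moving statistics alone, but calling the very same object afterwards with `training=True` updates them. So `training` behaves as an argument of each individual call rather than a mode permanently stored on the model object. This sketch only covers direct calls, not how Keras records the argument inside a functional model.

```python
import numpy as np
from tensorflow import keras

base_model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.BatchNormalization()])
bn = base_model.layers[0]
x = np.full((8, 3), 10.0, dtype="float32")

mean_start = np.array(bn.moving_mean)                 # initialized to zeros
base_model(x, training=False)                         # inference mode: moving stats only read
mean_after_inference_call = np.array(bn.moving_mean)  # still zeros
base_model(x, training=True)                          # same object, training mode for this call
mean_after_training_call = np.array(bn.moving_mean)   # 0.99 * 0 + 0.01 * 10 = 0.1 per feature
```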

Hi, I read the section on inference mode in connection with data augmentation in the article you linked, since images should only be augmented during training (see screenshot). Is this what you mean? Unfortunately, I do not know how to set inference mode. I switched to ImageDataGenerator and apply it only to the training data (flow_from_dataframe). By the way, I'm still looking for the best documentation on all these transfer learning and fine-tuning topics. Best regards, Sina

Hi Sina,

When calling predict, the model is run in inference mode. To use inference mode for the base model when doing transfer learning and/or fine-tuning, you have to pass the argument `training=False` when calling the base model. For more details, please refer to the code example in the previous post in this thread.
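A small sketch of the point about predict, on a hypothetical model with a `Dropout` layer (which, like the random augmentation layers in Sina's question, only acts in training mode): `model.predict(x)` matches `model(x, training=False)`, while a `training=True` call does not.

```python
import numpy as np
from tensorflow import keras

keras.utils.set_random_seed(0)
model = keras.Sequential(
    [
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dropout(0.5),  # only active in training mode
        keras.layers.Dense(1),
    ]
)
x = np.random.default_rng(0).normal(size=(4, 8)).astype("float32")

pred = model.predict(x, verbose=0)              # runs in inference mode
infer = np.asarray(model(x, training=False))    # explicit inference mode
train = np.asarray(model(x, training=True))     # dropout randomly zeroes units
print(np.allclose(pred, infer, atol=1e-6))
print(np.allclose(pred, train, atol=1e-6))
```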

Thank you for your explanation :smiley: