Help understanding fine-tuning tutorial

Sorry for spamming the forum, but I'm having trouble understanding the Eager Few Shot OD Training TF2 tutorial.

For this part:

detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)

fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')

I don’t see how we actually restore the weights. As far as I can tell, we create a checkpoint object called fake_model that just points at parts of the model itself (the bare ssd_resnet50 architecture with no weights, except for random initial values).
We then call restore on the provided checkpoint, but this wrapper doesn’t seem to be linked to the model (detection_model) that is going to be trained in any way. Aren’t we calling restore on a checkpoint object that isn’t connected to the model we are going to train?
If so, the model (detection_model) wouldn’t contain any of the weights from the checkpoint file.

In my mind this should be:

fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor,
          model=detection_model)

fake_model.restore(checkpoint_path).expect_partial()
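
As a side note, listing what the provided checkpoint actually contains might be relevant here (checkpoint_path is the same variable as in the tutorial code above):

import tensorflow as tf

# Print the variable names and shapes stored in the fine-tuning checkpoint,
# i.e. what the restore call has to match against.
for name, shape in tf.train.list_variables(checkpoint_path):
    print(name, shape)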

Thanks for any help and clarification!

It took me some time to understand this too! :slight_smile:

Think like this:

  • detection_model is loaded from a configuration with random weights
  • this structure is used as the base for fake_box_predictor and fake_model.
  • the weights are loaded into fake_model; since fake_model points at pieces of detection_model (the same Python objects), detection_model’s weights are populated by the load too (see the toy sketch below).
  • finally, a fake image is run through detection_model so that all of its variables actually get created (and the deferred restore can fill them in).

does it make sense?
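
If it helps, here is a tiny standalone sketch of the key idea (a toy Dense layer and a made-up attribute name, nothing from the actual OD API): tf.train.Checkpoint restores by object reference, so even a throwaway wrapper fills in the weights of the real objects it points at.

import tensorflow as tf

# Toy stand-in for one piece of detection_model.
layer = tf.keras.layers.Dense(4)
layer.build((None, 2))

# Save known weights through one wrapper Checkpoint.
layer.kernel.assign(tf.ones_like(layer.kernel))
path = tf.train.Checkpoint(box_head=layer).save('/tmp/demo_ckpt')

# Pretend the layer was freshly initialized with "random" values.
layer.kernel.assign(tf.zeros_like(layer.kernel))

# Restore through a *different* wrapper that points at the same layer object.
tf.train.Checkpoint(box_head=layer).restore(path).expect_partial()

print(layer.kernel.numpy())  # all ones again: the shared object received the weights

The fake_box_predictor/fake_model wrappers in the tutorial play the role of the second Checkpoint here: they only exist to mirror the attribute names stored in the checkpoint, while the actual variables they lead to belong to detection_model.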

Thank you for your reply! I think I might understand now:

Since we set the wrapper checkpoint’s feature extractor and box predictor to be detection_model’s own feature extractor and box predictor (detection_model._feature_extractor and detection_model._box_predictor), the weights for those parts of detection_model get set to whatever values the checkpoint holds for them?
And the other weights are still just random initial values, since they are not passed as arguments to the Checkpoint constructor and hence not restored?
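
To check my understanding, this is roughly how I picture it with a toy example (made-up names, nothing from the actual OD API): only the objects attached to the restoring Checkpoint get matched, and everything else keeps its initial values.

import tensorflow as tf

# Two toy "heads": one we attach to the restoring Checkpoint, one we leave out.
box_head = tf.keras.layers.Dense(4)
class_head = tf.keras.layers.Dense(4)
box_head.build((None, 2))
class_head.build((None, 2))

# Save a checkpoint in which both heads hold all-ones weights.
box_head.kernel.assign(tf.ones_like(box_head.kernel))
class_head.kernel.assign(tf.ones_like(class_head.kernel))
path = tf.train.Checkpoint(box=box_head, cls=class_head).save('/tmp/partial_demo')

# Re-initialize both (stand-in for the randomly initialized model).
box_head.kernel.assign(tf.zeros_like(box_head.kernel))
class_head.kernel.assign(tf.zeros_like(class_head.kernel))

# Restore through a wrapper that only knows about the box head,
# mirroring how the tutorial comments out _prediction_heads.
tf.train.Checkpoint(box=box_head).restore(path).expect_partial()

print(box_head.kernel.numpy().mean())    # 1.0 -> restored from the checkpoint
print(class_head.kernel.numpy().mean())  # 0.0 -> kept its "initial" values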

Thanks!
