Retrieval Model with Context Features Doesn't Learn

I’ve built a recommender model with a retrieval task and context features for both the query and candidate towers, and it doesn’t learn: the validation metrics val_factorized_top_k/top_100_categorical_accuracy, val_factorized_top_k/top_50_categorical_accuracy, and so on barely budge. If I build the same model with a ranking task instead, it learns nicely and ends up with a solid ROC-AUC. The target label is binary (0 or 1) and indicates whether the user clicked the item when she was prompted to.

What am I doing wrong with the retrieval model?

In the main model’s constructor, I create the retrieval task as follows:

retrieval_candidates = unique_candidate_ds.batch(128).map(self.candidate_model)
retrieval_metrics = tfrs.metrics.FactorizedTopK(candidates = retrieval_candidates)
self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
  metrics = retrieval_metrics,
  batch_metrics = [tf.keras.metrics.AUC(from_logits = True, name = "retrieval_auc")])
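As I understand FactorizedTopK, a top-K categorical accuracy is conceptually a hit rate: for each (query, clicked item) pair, score the query against every candidate and count a hit if the clicked item ranks within the top K. Below is a rough, dependency-free sketch of that idea in plain Python, not the TFRS implementation; the function names and toy embeddings are made up for illustration, and tie handling in the real metric may differ.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot-product similarity between a query and a candidate embedding."""
    return sum(x * y for x, y in zip(a, b))


def top_k_hit(query: Sequence[float], true_index: int,
              candidates: List[Sequence[float]], k: int) -> bool:
    """True if the clicked candidate ranks within the top k of all candidates."""
    scores = [dot(query, c) for c in candidates]
    true_score = scores[true_index]
    # Count candidates that score strictly higher than the clicked one.
    better = sum(1 for s in scores if s > true_score)
    return better < k


def top_k_accuracy(queries: List[Sequence[float]], clicked: List[int],
                   candidates: List[Sequence[float]], k: int) -> float:
    """Fraction of (query, clicked item) pairs that are top-k hits."""
    hits = [top_k_hit(q, t, candidates, k) for q, t in zip(queries, clicked)]
    return sum(hits) / len(hits)


# Toy data: three candidate embeddings and two queries.
candidates = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
queries = [[1.0, 0.0], [0.0, 1.0]]
clicked = [0, 2]  # index of the item each query actually clicked
print(top_k_accuracy(queries, clicked, candidates, k=1))  # 0.5
print(top_k_accuracy(queries, clicked, candidates, k=2))  # 1.0
```

Because every candidate is scored, this metric only moves when the query and candidate embeddings place clicked items above the rest of the catalog, not just above the other items in a batch.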

Note that unique_candidate_ds is a dataset of item IDs as well as associated context features that are inputs into the candidate tower.
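Conceptually, building unique_candidate_ds means keeping one row per item ID together with that item’s candidate-side features. A minimal plain-Python sketch, assuming each row is a dict and using the hypothetical keys item_id and item_category; keeping the first occurrence is just one choice if an item’s features vary between interactions.

```python
from typing import Dict, Iterable, List


def unique_candidates(rows: Iterable[Dict[str, object]],
                      id_key: str = "item_id") -> List[Dict[str, object]]:
    """Keep the first row seen for each item ID, preserving order."""
    seen = set()
    unique = []
    for row in rows:
        item_id = row[id_key]
        if item_id not in seen:
            seen.add(item_id)
            # Keep only one copy of each item and its candidate-side features.
            unique.append(row)
    return unique


# Hypothetical interaction rows; "b" appears twice.
rows = [
    {"item_id": "a", "item_category": "news"},
    {"item_id": "b", "item_category": "sports"},
    {"item_id": "b", "item_category": "sports"},
]
print([r["item_id"] for r in unique_candidates(rows)])  # ['a', 'b']
```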

The candidate model’s call method is:

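  def call(
      self,
      features: Dict[str, tf.Tensor]) -> tf.Tensor:
    embedding_outputs = []
    for name, embedding in self.embeddings.items():
      # Embed each input feature with its own embedding layer.
      embedding_outputs.append(embedding(features[name]))
    # Concatenate along the feature axis (axis 0 is the batch) and project.
    output = self.dense_output(tf.concat(embedding_outputs, axis = 1))
    return output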

(The query model is similar, as it also has context features and a dense output layer.)
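Roughly, each tower embeds every feature, concatenates the embeddings along the feature axis, and projects the result with a dense layer. Here is a simplified, dependency-free sketch of that forward pass for a single example (there is no batch dimension, which is why the concatenation is a plain list extend rather than tf.concat on axis 1). It assumes every feature is a categorical index into its own embedding table; the tower function, table names, and weights are hypothetical.

```python
from typing import Dict, List


def tower(features: Dict[str, int],
          embedding_tables: Dict[str, List[List[float]]],
          dense_weights: List[List[float]],
          dense_bias: List[float]) -> List[float]:
    """Embed each feature, concatenate, then apply one linear dense layer."""
    embedding_outputs: List[float] = []
    for name, table in embedding_tables.items():
        # Look up this feature's embedding row and append it to the concatenation.
        embedding_outputs.extend(table[features[name]])
    # Dense layer with no activation: out[j] = sum_i x[i] * W[i][j] + b[j]
    return [
        sum(x * dense_weights[i][j] for i, x in enumerate(embedding_outputs))
        + dense_bias[j]
        for j in range(len(dense_bias))
    ]


# Hypothetical tables: a 2-d item-ID embedding and a 1-d category embedding.
tables = {"item_id": [[1.0, 0.0], [0.0, 1.0]], "category": [[0.5], [2.0]]}
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 inputs -> 2 outputs
bias = [0.0, 0.0]
print(tower({"item_id": 1, "category": 0}, tables, weights, bias))  # [0.5, 1.5]
```

Both towers must produce outputs of the same width, since the retrieval task compares them with a dot product.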

In the main model, I have:

  def compute_loss(
      self,
      features: Dict[str, tf.Tensor],
      training: bool = False) -> tf.Tensor:
    # The click label is removed from the features; the retrieval task itself
    # does not use it (every row is treated as a positive pair).
    labels = features.pop("rating")
    query_output = self.query_model(features)
    candidate_output = self.candidate_model(features)
    retrieval_loss = self.retrieval_task(
        query_output, candidate_output, compute_metrics = not training)

    return retrieval_loss
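As far as I understand tfrs.tasks.Retrieval, its loss treats each batch as a softmax over the in-batch candidates: every query is scored against every candidate in the batch with a dot product, and row i’s positive is candidate i, with the other candidates acting as negatives. Below is a plain-Python sketch of that per-row cross-entropy under those assumptions; the function name is hypothetical, and the real task also handles the reduction, sample weights, and optional extras such as temperature.

```python
import math
from typing import List, Sequence


def in_batch_softmax_losses(queries: List[Sequence[float]],
                            candidates: List[Sequence[float]]) -> List[float]:
    """Per-row cross-entropy where row i's positive is candidate i and the
    other candidates in the batch act as negatives."""
    losses = []
    for i, q in enumerate(queries):
        # Dot-product score of query i against every candidate in the batch.
        logits = [sum(a * b for a, b in zip(q, c)) for c in candidates]
        # Numerically stable log-sum-exp over the batch.
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log softmax(logits)[i]: the "label" is the diagonal entry.
        losses.append(log_norm - logits[i])
    return losses


# With uninformative (all-zero) embeddings, every row's loss is ln(batch size).
zeros = [[0.0, 0.0]] * 3
print(in_batch_softmax_losses(zeros, zeros))  # each entry equals math.log(3)
```

For a batch of B rows, embeddings that carry no signal leave each row at ln(B), which is a handy per-row baseline when reading the loss curves.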

I compile the model with:

model_r.compile(optimizer = tf.keras.optimizers.Adam(0.001))

Why doesn’t the retrieval model learn, when the same model with a ranking task learns nicely?


Over five training epochs, the loss does not decrease; it actually increases slightly, from 8.225164 to 8.225174.

The top-K categorical accuracies also stay flat.

The loss and categorical accuracies above are the validation metrics from the training history. When I plot the training metrics instead, the loss decreases nicely and the top-K categorical accuracies rise as expected.

My call to fit() is:

retrieval_model_history = model_r.fit(
    # (training dataset argument not shown in the original post)
    validation_data = cached_positive_test_ds,
    validation_freq = 1,
    epochs = num_epochs,
    verbose = 1)

I’m assuming here that I should pass only the positive interactions (those where the user clicked the item when prompted) for both the training and validation datasets.
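Filtering to positives could look like this plain-Python sketch, which keeps only rows whose "rating" label is 1; the row layout (a dict per interaction with user_id and item_id keys) is an assumption for illustration.

```python
from typing import Dict, Iterable, List


def positives_only(rows: Iterable[Dict[str, object]],
                   label_key: str = "rating") -> List[Dict[str, object]]:
    """Keep only interactions whose label is 1 (the user clicked the item)."""
    return [row for row in rows if row[label_key] == 1]


# Hypothetical interactions; only the clicked one survives the filter.
interactions = [
    {"user_id": "u1", "item_id": "a", "rating": 1},
    {"user_id": "u1", "item_id": "b", "rating": 0},
]
print(positives_only(interactions))  # [{'user_id': 'u1', 'item_id': 'a', 'rating': 1}]
```

With tf.data, the same idea would be a Dataset.filter call whose predicate checks that x["rating"] equals 1, applied to both the training and validation datasets.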