Feature Importance of Recommender Model

Does anyone know of code to determine feature importance of a TensorFlow Recommenders model?

I’ve seen the SHAP library from Scott Lundberg, but it’s not clear to me which kind of explainer to use and how to invoke it.

My model is similar to the MovieLens example, combining dense layers, side features, and multiple tasks.


@Wei_Wei might be able to help here


@rcauvin Were you able to use SHAP with TensorFlow Recommenders? Here is what I tried:

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs
from typing import Dict, Text

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

# Select the basic features.
ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "user_rating": x["user_rating"],
})
movies = movies.map(lambda x: x["movie_title"])

# Randomly shuffle data and split between train and test.
tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

movie_titles = movies.batch(1_000)
user_ids = ratings.batch(1_000_000).map(lambda x: x["user_id"])

unique_movie_titles = np.unique(np.concatenate(list(movie_titles)))
unique_user_ids = np.unique(np.concatenate(list(user_ids)))

class MovielensModel(tfrs.models.Model):

  def __init__(self, rating_weight: float, retrieval_weight: float) -> None:
    # We take the loss weights in the constructor: this allows us to
    # instantiate several model objects with different loss weights.
    super().__init__()

    embedding_dimension = 32

    # User and movie models.
    self.movie_model: tf.keras.layers.Layer = tf.keras.Sequential([
        tf.keras.layers.StringLookup(
            vocabulary=unique_movie_titles, mask_token=None),
        tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
    ])
    self.user_model: tf.keras.layers.Layer = tf.keras.Sequential([
        tf.keras.layers.StringLookup(
            vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
    ])

    # A small model to take in user and movie embeddings and predict ratings.
    # We can make this as complicated as we want as long as we output a scalar
    # as our prediction.
    self.rating_model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(1),
    ])

    # The tasks.
    self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
        loss=tf.keras.losses.MeanSquaredError(),
        metrics=[tf.keras.metrics.RootMeanSquaredError()],
    )
    self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.movie_model)
        )
    )

    # The loss weights.
    self.rating_weight = rating_weight
    self.retrieval_weight = retrieval_weight

  def call(self, features: Dict[Text, tf.Tensor]) -> tf.Tensor:
    # We pick out the user features and pass them into the user model.
    user_embeddings = self.user_model(features["user_id"])
    # And pick out the movie features and pass them into the movie model.
    movie_embeddings = self.movie_model(features["movie_title"])

    return (
        user_embeddings,
        movie_embeddings,
        # We apply the multi-layered rating model to a concatenation of
        # user and movie embeddings.
        self.rating_model(
            tf.concat([user_embeddings, movie_embeddings], axis=1)
        ),
    )

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    ratings = features.pop("user_rating")
    user_embeddings, movie_embeddings, rating_predictions = self(features)

    # We compute the loss for each task.
    rating_loss = self.rating_task(
        labels=ratings,
        predictions=rating_predictions,
    )
    retrieval_loss = self.retrieval_task(user_embeddings, movie_embeddings)

    # And combine them using the loss weights.
    return (self.rating_weight * rating_loss
            + self.retrieval_weight * retrieval_loss)

model = MovielensModel(rating_weight=1.0, retrieval_weight=0.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

cached_train = train.shuffle(100_000).batch(8192).cache()
cached_test = test.batch(4096).cache()

model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

model = MovielensModel(rating_weight=0.0, retrieval_weight=1.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)
metrics = model.evaluate(cached_test, return_dict=True)

train_np = np.stack(list(train))

trained_movie_embeddings, trained_user_embeddings, predicted_rating = model({
    "user_id": np.array(["42"]),
    "movie_title": np.array(["Dances with Wolves (1990)"])
})
print("Predicted rating:")
print(predicted_rating)

model = MovielensModel(rating_weight=1.0, retrieval_weight=0.0)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

import shap

background = train_np[np.random.choice(train_np.shape[0], 100, replace=False)]
explainer = shap.DeepExplainer(model, background)

This failed with: object of type 'NoneType' has no len()

I also tried passing the model's input and output tensors explicitly:

explainer = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)

which produced:

tf.Tensor([[3.402324]], shape=(1, 1), dtype=float32)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
in ()
      8 background = train_np[np.random.choice(train_np.shape[0], 100, replace=False)]
      9 # explainer = shap.DeepExplainer(model, background)
---> 10 explainer = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)

/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in output(self)
   2167     """
   2168     if not self._inbound_nodes:
-> 2169       raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
   2170     return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
   2171

AttributeError: Layer retrieval_2 has no inbound nodes.

I haven't tried using shap.DeepExplainer, but I was able to get shap.KernelExplainer to work. (DeepExplainer appears to need the model's symbolic input and output tensors, which a subclassed TFRS model doesn't expose; hence the "no inbound nodes" error above. KernelExplainer treats the model as a black box, so it sidesteps that problem.)

I found that the SHAP explainer expects the input to the model to be a uniformly-typed 2D numpy array (as stated here), whereas my model expects a dictionary of tensors. I had to create an adapter and pass its predict function as the model to the explainer. The adapter's predict function accepts a 2D numpy array as input, converts it to the input the TFRS model expects, and calls the TFRS model's predict function.

Here is the adapter code:

# SHAP expects model functions to take a 2D numpy array as input,
# whereas our model.predict function expects a dictionary of
# tensors. So we define a class that adapts 2D numpy array input
# to a batched dictionary dataset.

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_recommenders as tfrs

class SHAPPredictor:

    def __init__(
        self,
        model: tfrs.models.Model,
        tensor_slice_specs: dict,
        target_name: str):

        self.model = model
        self.tensor_slice_specs = tensor_slice_specs
        self.target_name = target_name

    # Convert typed model inputs to the uniformly-typed input SHAP expects.
    def convert_to_shap_input(
        self,
        ds: tf.data.Dataset,
        sample_size: int):

        sample_ds = ds.unbatch().take(sample_size)
        x_sample = pd.DataFrame(sample_ds.as_numpy_iterator()).drop(self.target_name, axis = 1)
        input_df = x_sample.applymap(lambda x: str(x) if isinstance(x, int) else x.decode("utf-8", "ignore"))

        return input_df

    # Convert uniformly-typed input to the typed input the model expects.
    def convert_to_model_input(
        self,
        X: np.ndarray):

        num_columns = X.shape[1]
        rejiggered = [X[:, i] for i in range(num_columns)]
        tensor_slices = {
            name: tf.convert_to_tensor(rejiggered[index], dtype = dtype)
            for index, (name, dtype) in enumerate(self.tensor_slice_specs.items())
            if index < num_columns
        }
        input_ds = tf.data.Dataset.from_tensor_slices(tensor_slices)

        return input_ds

    # Adapt input and invoke the model to make predictions.
    def predict(
        self,
        X: np.ndarray):

        input_ds = self.convert_to_model_input(X)
        predictions = self.model.predict(input_ds.batch(50))

        return predictions

Here is the code to instantiate the adapter, create the explainer, and determine the SHAP values:

import shap

train_sample_size = 10000
test_sample_size = 100
background_sample_size = 50
shap_sample_size = 100

print("Creating a predictor to adapt SHAP uniformly-typed input to typed input the model expects.", end = " ")
shap_predictor = SHAPPredictor(ranking_model, tensor_slice_specs, target_name = "rating")
print("Done.")

print(f"Getting {train_sample_size} records from train data as dataframe.", end = " ")
x_train = shap_predictor.convert_to_shap_input(cached_train_ds, sample_size = train_sample_size)
print("Done.")

print(f"Getting {test_sample_size} records from test data as dataframe.", end = " ")
x_test = shap_predictor.convert_to_shap_input(cached_test_ds, sample_size = test_sample_size)
print("Done.")

print(f"Sampling {background_sample_size} background records from {len(x_train)} train data records.", end = " ")
background = shap.sample(x_train, background_sample_size)
print("Done.")

print(f"Creating explainer.", end = " ")
explainer = shap.KernelExplainer(shap_predictor.predict, background)
print("Done.")

print("Using explainer to determine SHAP values.")
shap_values = explainer.shap_values(x_test, nsamples = shap_sample_size)
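
Once you have the SHAP values, a simple way to turn them into a global feature-importance ranking is the mean absolute SHAP value per feature. Here is a minimal sketch (assuming shap_values comes back as a list with one array per model output, which is what KernelExplainer returns for multi-output models):

# Rank features by mean absolute SHAP value across the explained samples.
values = shap_values[0] if isinstance(shap_values, list) else shap_values

importance = pd.DataFrame({
    "feature": x_test.columns,
    "mean_abs_shap": np.abs(values).mean(axis = 0).ravel(),
}).sort_values("mean_abs_shap", ascending = False)
print(importance)

# SHAP can also plot the same ranking directly.
shap.summary_plot(values, x_test, plot_type = "bar")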

You'll notice the adapter relies on tensor_slice_specs to map a uniformly-typed 2D numpy array to the input the model expects. tensor_slice_specs is a dictionary of feature names and types, such as:

{"user_id": tf.string, "gender": tf.string, "marital_status": tf.string, "age_range": tf.string, "occupation": tf.string, "day_of_week": tf.int64, "hour_of_day": tf.int64, "item_id": tf.string, "rating": tf.int32}
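
If your features already live in a tf.data.Dataset, you can derive this dictionary from the dataset's element_spec rather than writing it out by hand. A minimal sketch (assuming, as in the code above, that cached_train_ds yields batched dictionaries of scalar features):

# Build tensor_slice_specs from the dataset's element_spec. Note that the
# adapter maps numpy array columns to features by position, so the key
# order here must match the column order produced by convert_to_shap_input.
tensor_slice_specs = {
    name: spec.dtype
    for name, spec in cached_train_ds.unbatch().element_spec.items()
}
print(tensor_slice_specs)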

Hopefully a similar approach will work for your case.