How to add more than one feature to the ranking model without getting a dimension problem?

Hi All,

Following the tutorial here, Using TensorFlow Recommenders with TFX, I am trying to extend the ranking model to handle more than one feature.

The models train with a single feature, but concatenating two features in one model, as shown in the tutorial here, Taking advantage of context features | TensorFlow Recommenders, causes the batch size to double, and the Ranking model will not accept this.
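For reference, the batch-doubling symptom is what you get from concatenating along the batch axis (axis=0) rather than the embedding axis. A minimal sketch with NumPy arrays standing in for two per-feature embedding tensors (the shapes here are illustrative, not the real batch size):

```python
import numpy as np

# Two per-feature embeddings for the same batch of examples
# (illustrative shapes: batch of 4, embedding dimension 8).
user_id_emb = np.ones((4, 8))
occupation_emb = np.ones((4, 8))

# Concatenating along axis=0 stacks the batches: batch size doubles.
stacked = np.concatenate([user_id_emb, occupation_emb], axis=0)
print(stacked.shape)   # (8, 8)

# Concatenating along axis=1 widens each example: batch size is preserved.
widened = np.concatenate([user_id_emb, occupation_emb], axis=1)
print(widened.shape)   # (4, 16)
```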

Any suggestions, please, on how to have multiple user and movie features in the Ranking model? Is this the correct approach? Will the Ranking model accept different batch sizes if initialised differently?

CC: @Robert_Crowe

class UserModel(tf.keras.Model):

    EMBEDDING_DIMENSION = 32  # not in the original snippet; assumed from the error shapes below

    def __init__(self, transform_output):
        super().__init__()
        self.transform_output = transform_output

    def _get_str_vocab_for_feature(self, feature_name):
        vocab_key = feature_name + '_vocab'
        return [b.decode() for b in self.transform_output.vocabulary_by_name(vocab_key)]

    def _get_embedding_for_str_input(self, inputs, feature_name):
        vocab = self._get_str_vocab_for_feature(feature_name)
        return tf.keras.Sequential([
            tf.keras.layers.Input(shape=(1,), name=feature_name, dtype=tf.string),
            tf.keras.layers.StringLookup(vocabulary=vocab, mask_token=None),
            tf.keras.layers.Embedding(len(vocab) + 1, self.EMBEDDING_DIMENSION),
        ])(inputs[feature_name])

    def _get_model(self, inputs):
        return tf.concat([
            self._get_embedding_for_str_input(inputs, key)
            for key in inputs.keys()
        ], axis=1)

    def call(self, inputs):
        return self._get_model(inputs)
This is the error I get:

ValueError: Dimensions must be equal, but are 8192 and 16384 for '{{node ranking/mean_squared_error/SquaredDifference}} = SquaredDifference[T=DT_FLOAT](Squeeze_1, Squeeze)' with input shapes: [8192,32], [16384,32].

Call arguments received by layer "ranking" (type Ranking):
  • labels=tf.Tensor(shape=(16384, 32), dtype=float32)
  • predictions=tf.Tensor(shape=(8192, 32), dtype=float32)
  • sample_weight=None
  • training=False
  • compute_metrics=True
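The traceback itself points at the loss: the Ranking task's mean-squared-error needs labels and predictions with the same leading (batch) dimension. A tiny NumPy reproduction of the shape clash, with the 8192/16384 values copied from the error above:

```python
import numpy as np

labels = np.zeros((16384, 32))
predictions = np.zeros((8192, 32))

# Squared error needs broadcast-compatible shapes; the mismatched batch
# dimensions (16384 vs 8192) trigger a ValueError, mirroring the traceback.
try:
    _ = (labels - predictions) ** 2
except ValueError as err:
    print("shape mismatch:", err)
```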