Pairs of ragged sequences as input for a TensorFlow model

I have a model that takes pairs of variable-length sequences as input.

My model does this:

  1. embeds each item of both sequences
  2. sums the embeddings of all items in each sequence
  3. calculates the dot product of the two sums
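
For a single pair, the three steps amount to roughly the following eager-mode sketch. The 4-token `table` and the names `seq_a` and `seq_b` are made up for illustration (they are not part of my model), and `tf.gather` on a constant table stands in for the trainable `Embedding` layer.

```python
import tensorflow as tf

# Hypothetical 4-token vocabulary with 2-dimensional embeddings,
# chosen only so the arithmetic is easy to follow by hand.
table = tf.constant([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0],
                     [2.0, 0.0]])

seq_a = tf.constant([0, 1, 2])  # first sequence of the pair
seq_b = tf.constant([3, 1])     # second sequence, with a different length

# 1. embed each item of both sequences
emb_a = tf.gather(table, seq_a)  # shape (3, 2)
emb_b = tf.gather(table, seq_b)  # shape (2, 2)

# 2. sum the embeddings of all items in each sequence
sum_a = tf.reduce_sum(emb_a, axis=0)  # [2., 2.]
sum_b = tf.reduce_sum(emb_b, axis=0)  # [2., 1.]

# 3. dot product of the two sums: 2*2 + 2*1 = 6.0
score = tf.reduce_sum(sum_a * sum_b)
```

The model has to do the same thing for a whole batch of pairs at once, which is where the ragged input comes in.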

Here is a minimal example that illustrates the idea:

import tensorflow as tf

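# Intended input shape: (batch, 2, None), a pair of variable-length sequences per example
inputs = tf.keras.Input(shape=(2, None), ragged=True, dtype=tf.int32)

embedder = tf.keras.layers.Embedding(input_dim=16, output_dim=16)
x = embedder(inputs)            # intended shape: (batch, 2, None, 16)
x = tf.reduce_sum(x, axis=-2)   # sum over items; intended shape: (batch, 2, 16)

v1 = x[:, 0, :]  # first sequence of each pair; the slicing op named in the error below
v2 = x[:, 1, :]  # second sequence of each pair
outputs = tf.reduce_sum(tf.multiply(v1, v2), axis=1)  # row-wise dot product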

model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

xs = tf.ragged.constant([
    [[0, 1, 2], [3, 4]],
    [[2, 0], [5]],
])
ys = tf.constant([0, 1])
dataset = tf.data.Dataset.from_tensor_slices((xs, ys))

model.fit(dataset)

Running this code raises the following error:

TypeError: Exception encountered when calling layer 'tf.__operators__.ragged_getitem' (type SlicingOpLambda).
    
    Ragged __getitem__ expects a ragged_tensor.
    
    Call arguments received by layer 'tf.__operators__.ragged_getitem' (type SlicingOpLambda):
      • rt_input=tf.Tensor(shape=(None, 16), dtype=float32)
      • key=({'start': 'None', 'stop': 'None', 'step': 'None'}, '0', {'start': 'None', 'stop': 'None', 'step': 'None'})

My questions are these:

  • Is it possible to do x[:, 0, :] indexing without getting the error?
  • What is the most idiomatic way to implement the idea using the TensorFlow 2 API?
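
Regarding the second question, one direction that might sidestep the x[:, 0, :] indexing altogether is to keep the two sides of each pair in separate ragged tensors. The sketch below is eager-mode only and rests on several assumptions: a constant `table` of ones replaces the trainable `Embedding` layer, and the names `left`, `right`, and `pair_scores` are hypothetical. It is not verified to be the idiomatic approach, nor to carry over unchanged to a Keras functional model.

```python
import tensorflow as tf

# Hypothetical embedding table (16 tokens, 16 dimensions), filled with ones
# only so that the result can be checked by hand.
table = tf.ones((16, 16))

# One ragged tensor per side of the pair, instead of a single (batch, 2, None) input.
left = tf.ragged.constant([[0, 1, 2], [2, 0]])
right = tf.ragged.constant([[3, 4], [5]])


def pair_scores(table, left, right):
    """Embed, sum over items, and take the row-wise dot product."""
    # tf.gather accepts ragged indices and returns a ragged (batch, None, dim) tensor.
    left_sum = tf.reduce_sum(tf.gather(table, left), axis=1)    # (batch, dim)
    right_sum = tf.reduce_sum(tf.gather(table, right), axis=1)  # (batch, dim)
    return tf.reduce_sum(left_sum * right_sum, axis=1)          # (batch,)


scores = pair_scores(table, left, right)
```

With the all-ones table, each score is simply (items on the left) × (items on the right) × 16, which makes the result easy to sanity-check.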