Learning aggregators using neural networks in TensorFlow 2.0

I am trying to implement the following neural network model in TensorFlow:

The model has two inputs. The first, X, is a list of n vectors of dimension 3. The second, Y, is a list of n non-negative integers sorted in ascending order and starting from 0. The model’s output, Z, is a list of m vectors of dimension 3.

Y contains m unique values, and each value gives the class of the corresponding input vector in X. Different classes may contain different numbers of input vectors.
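A quick NumPy check of how Y determines m, using the example Y from the code below. This sketch assumes the class ids are contiguous (0 to m - 1); `tf.math.segment_sum` returns max(Y) + 1 rows, so a gap in the ids would produce extra all-zero rows rather than exactly m rows.

```python
import numpy as np

# Example class ids from the question: sorted, starting at 0, one id per row of X.
Y = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3], dtype=np.int32)

# m is the number of distinct class ids.
m = len(np.unique(Y))

# Number of input vectors in each class; the counts need not be equal.
counts = np.bincount(Y)

# segment_sum expects sorted (non-decreasing) ids.
assert np.all(np.diff(Y) >= 0)

print(m)       # 4
print(counts)  # [2 2 3 3]
```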

The first layer maps each vector in X to a vector of dimension 2 and applies the ‘gelu’ activation. The second layer applies ‘segment_sum’ with Y as the segment ids, reducing the n vectors of dimension 2 to m vectors of dimension 2. The third layer maps these m vectors of dimension 2 to m vectors of dimension 3, which form the model’s output.
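The three layers can be traced on the example data to confirm the intended shapes (n, 3) → (n, 2) → (m, 2) → (m, 3). This is a sketch that calls the Keras layers directly instead of through a model; the random inputs are illustrative only.

```python
import numpy as np
import tensorflow as tf

n, m = 10, 4
X = np.random.random((n, 3)).astype("float32")
Y = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3], dtype="int32")

dense1 = tf.keras.layers.Dense(2, activation="gelu")  # per input vector: 3 -> 2
dense2 = tf.keras.layers.Dense(3)                      # per class: 2 -> 3

H = dense1(X)                    # shape (n, 2)
S = tf.math.segment_sum(H, Y)    # shape (m, 2): rows sharing an id are summed
Z = dense2(S)                    # shape (m, 3)

print(H.shape, S.shape, Z.shape)  # (10, 2) (4, 2) (4, 3)

# Row 2 of S is the sum of rows 4, 5 and 6 of H (the three vectors with id 2).
np.testing.assert_allclose(S[2].numpy(), H[4:7].numpy().sum(axis=0),
                           rtol=1e-5, atol=1e-6)
```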

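I train the model with a cosine loss (Keras’s CosineSimilarity, which returns the negative cosine similarity, so minimizing it maximizes similarity) and the Adam optimizer.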
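A small worked example of the sign convention, assuming the `tf.keras.losses.CosineSimilarity(axis=1)` loss used in the code: one prediction points in the same direction as its target and one is orthogonal, so the per-row similarities are 1 and 0 and the returned loss is their negated mean.

```python
import numpy as np
import tensorflow as tf

y_true = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype="float32")
y_pred = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 1.0]], dtype="float32")

loss_fn = tf.keras.losses.CosineSimilarity(axis=1)
loss = float(loss_fn(y_true, y_pred))

# Row 0: same direction, cosine = 1. Row 1: orthogonal, cosine = 0.
# Keras negates and averages over the batch, so the loss is approximately -0.5.
print(loss)
```

A perfect fit therefore drives this loss toward -1, not 0.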

Here is the code that I wrote for this purpose:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Prepare the input and output data (example)
n = 10
m = 4
X = np.random.random((n, 3)).astype('float32')
Y = np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3]).astype('int32')
Z = np.random.random((m, 3)).astype('float32')

class CustomModel(tf.keras.Model):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.dense1 = keras.layers.Dense(2, activation='gelu')
        self.dense2 = keras.layers.Dense(3)

    def call(self, inputs):
        X, Y = inputs
        X = self.dense1(X)
        X = tf.math.segment_sum(X, Y)
        Z = self.dense2(X)
        return Z

model = CustomModel()

model.compile(loss=tf.keras.losses.CosineSimilarity(axis=1), optimizer=tf.keras.optimizers.Adam())

model.fit([X, Y], Z, epochs=10)

In essence, the model is meant to learn an aggregation function. However, running the code gives the following error:

Traceback (most recent call last):
  File "/home/nitesh/PycharmProjects1/pythonProject/research/reasoning_with_vectors/custom_model.py", line 31, in <module>
    model.fit([X, Y], Z, epochs=10)
  File "/home/nitesh/miniconda3/envs/relbert/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/nitesh/miniconda3/envs/relbert/lib/python3.10/site-packages/keras/engine/data_adapter.py", line 1852, in _check_data_cardinality
    raise ValueError(msg)
ValueError: Data cardinality is ambiguous:
  x sizes: 10, 10
  y sizes: 4
Make sure all arrays contain the same number of samples.
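The error comes from `model.fit`, which treats the first axis of every input and target as the sample axis: X and Y have 10 rows but Z has 4, so Keras cannot pair them. Since the whole (X, Y) pair maps to the whole Z, one option is to treat it as a single training example and write a custom training loop with `tf.GradientTape`, which skips `fit`’s cardinality check. The sketch below assumes that interpretation; the learning rate of 0.01 is an arbitrary choice, not from the question.

```python
import numpy as np
import tensorflow as tf

n, m = 10, 4
X = tf.constant(np.random.random((n, 3)).astype("float32"))
Y = tf.constant(np.array([0, 0, 1, 1, 2, 2, 2, 3, 3, 3], dtype="int32"))
Z = tf.constant(np.random.random((m, 3)).astype("float32"))

class CustomModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(2, activation="gelu")
        self.dense2 = tf.keras.layers.Dense(3)

    def call(self, inputs):
        x, y = inputs
        h = self.dense1(x)               # (n, 2)
        s = tf.math.segment_sum(h, y)    # (m, 2)
        return self.dense2(s)            # (m, 3)

model = CustomModel()
loss_fn = tf.keras.losses.CosineSimilarity(axis=1)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)  # assumed value

@tf.function
def train_step(x, y, z):
    # One gradient update on the full set: all n inputs vs. all m targets.
    with tf.GradientTape() as tape:
        z_pred = model((x, y))
        loss = loss_fn(z, z_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

losses = []
for epoch in range(10):
    losses.append(float(train_step(X, Y, Z)))
    print(f"epoch {epoch + 1}: loss = {losses[-1]:.4f}")
```

If several (X, Y, Z) sets are available, the same `train_step` can be called once per set in each epoch. Keeping `model.fit` would instead require giving X, Y and Z a shared leading sample axis (for example, size 1) and removing it inside `call`; that variant is not shown here.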

I have tried many approaches but could not find a way to implement this model in TensorFlow 2.0. Could anyone help? Thanks in advance.