Multi-input TF model using TFRecord datasets

I’m trying to create a multi-input, single-output model in TensorFlow. I load the data from TFRecord files using the `get_batched_dataset` function.

def get_batched_dataset(filenames, batch_size):
    """Build a batched ``tf.data`` pipeline of ``(x, y)`` pairs from TFRecords.

    Args:
        filenames: TFRecord file path (or list of paths) to read.
        batch_size: Number of examples per batch. ``drop_remainder`` is left
            at its default, so the final batch may be smaller than this --
            model code must therefore not hard-code the batch dimension.

    Returns:
        A ``tf.data.Dataset`` of batched ``(x, y)`` pairs as produced by
        ``prepare_sample``. It is already batched, so it is passed to
        ``Model.fit`` without a ``batch_size`` argument.
    """
    # AUTO is defined elsewhere; presumably tf.data.AUTOTUNE -- confirm.
    dataset = (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .map(parse_tfrecord_fn, num_parallel_calls=AUTO)
        .map(prepare_sample, num_parallel_calls=AUTO)
        .batch(batch_size)
        # Prepare the next batches while the current training step runs.
        .prefetch(AUTO)
    )
    return dataset

In the function above, I preprocess the data loaded from the TFRecords using the `prepare_sample` function.

def prepare_sample(features):
    """Convert one parsed TFRecord example into an ``(x, y)`` training pair.

    Args:
        features: dict of parsed example tensors; only ``features['image']``
            is read in the visible code.

    Returns:
        ``(x, y)`` where ``x`` is a dict with keys 'l_eye', 'r_eye' and
        'kps', and ``y`` is the target. The dict keys must match the ones
        looked up in ``main_model.call``.
    """
    image = features['image']
    # NOTE(review): if the image is laid out (height, width, channels),
    # index 0 is the height and index 1 the width, so `w` and `h` look
    # swapped -- confirm the layout. Neither is used in the visible code.
    w = tf.shape(image)[0]
    h = tf.shape(image)[1]
         
    # some type of preprocessing/data augmentation/transforms    
    # (elided in this snippet: it must define l_eye, r_eye, kps and out,
    # which are otherwise undefined below)
    
    x = {'l_eye': l_eye, 'r_eye':r_eye, 'kps':kps}  #l_eye & r_eye are images, kps is numerical data
    # NOTE(review): kps presumably needs a float dtype because it is fed to a
    # Dense layer in num_model -- confirm.
    y = out
    
    return x, y

Below is a heavily simplified version of my model architecture, just to give an idea of its structure.

class cnn_model(layers.Layer):
  """Convolutional feature extractor for a single eye image.

  Holds one ``Conv2D`` (32 filters, 7x7 kernel, stride 2).
  """
  def __init__(self, name='cnn-model', **kwargs):
    # Forward ``name`` (previously accepted but silently ignored) and any
    # other standard Layer keyword arguments.
    super(cnn_model, self).__init__(name=name, **kwargs)
    self.conv = layers.Conv2D(32, kernel_size=7, strides=2)

  def call(self, input_image):
    """Apply the convolution.

    Args:
      input_image: assumed to be a batched image tensor of shape
        (batch, height, width, channels) -- TODO confirm against the
        output of prepare_sample.

    Returns:
      The Conv2D output, whose last axis has 32 channels.
    """
    x = self.conv(input_image)
    return x

class num_model(layers.Layer):
  """Dense encoder for the numeric keypoint input: Dense(128) -> Dense(16)."""
  def __init__(self, name='num-model', **kwargs):
    # Forward ``name`` (previously accepted but silently ignored) and any
    # other standard Layer keyword arguments.
    super(num_model, self).__init__(name=name, **kwargs)

    # NOTE(review): neither Dense layer has an activation, so together they
    # compose to a single linear map; consider e.g. activation='relu' on
    # dense1 if a non-linear encoder is intended.
    self.dense1 = layers.Dense(128)
    self.dense2 = layers.Dense(16)

  def call(self, input_keypoints):
    """Encode the keypoints.

    Args:
      input_keypoints: assumed to be a float tensor of shape
        (batch, num_features) -- TODO confirm. Dense acts on the last axis,
        so a higher-rank input would keep its extra axes.

    Returns:
      Tensor whose last axis has size 16.
    """
    x = self.dense1(input_keypoints)
    x = self.dense2(x)
    return x


class main_model(Model):
  """Multi-input model combining two eye images and keypoints.

  The input is a dict with keys 'l_eye', 'r_eye' and 'kps'. Both eye images
  go through the same ``cnn_model`` instance, so their conv weights are
  shared. The keypoints go through ``num_model``. The flattened features are
  concatenated and mapped to a 2-unit output.
  """
  def __init__(self, name='main-model', **kwargs):
    # Forward ``name`` (previously accepted but silently ignored).
    super(main_model, self).__init__(name=name, **kwargs)

    self.cnn_model = cnn_model()
    self.num_model = num_model()
    # Flatten keeps the (dynamic) batch axis and collapses all other axes.
    # It replaces the old tf.reshape(..., (3, 128*128)), which only worked
    # for a batch of exactly 3 (the last batch can be smaller) and ignored
    # the 32 conv channels. Flatten is stateless, so one instance is reused.
    self.flatten_layer = layers.Flatten()

    self.dense1 = layers.Dense(8)
    # Linear output (no activation) -- make sure this matches the loss
    # passed to compile().
    self.dense2 = layers.Dense(2)

  def call(self, input_l_r_kps):
    """Run the forward pass.

    Args:
      input_l_r_kps: dict with keys 'l_eye', 'r_eye' (batched images) and
        'kps' (keypoint features), e.g. as produced by prepare_sample.

    Returns:
      Tensor of shape (batch, 2).
    """
    leftEye = input_l_r_kps['l_eye']
    rightEye = input_l_r_kps['r_eye']
    # Previously read from the undefined name `input_l_r_lms` (NameError).
    kps = input_l_r_kps['kps']

    l_eye_feat = self.flatten_layer(self.cnn_model(leftEye))
    r_eye_feat = self.flatten_layer(self.cnn_model(rightEye))
    kp_feat = self.num_model(kps)

    # Concatenate along the feature axis. This previously referenced the
    # undefined name `lm_feat` instead of `kp_feat`.
    combined_feat = tf.concat((l_eye_feat, r_eye_feat, kp_feat), 1)

    x = self.dense1(combined_feat)
    x = self.dense2(x)

    return x

The dataset returned by the `get_batched_dataset` function is what I feed into the Keras `model.fit` method:

# NOTE(review): `model` must already be created and compiled before this
# point, e.g. `model = main_model()` followed by `model.compile(...)` with an
# optimizer and loss -- that step is not shown here.
train_dataset = get_batched_dataset('train.tfrec', batch_size)
valid_dataset = get_batched_dataset('valid.tfrec', batch_size)

# The datasets are already batched by get_batched_dataset, so `batch_size` is
# not passed: Keras documents that it must not be specified for dataset input.
# `use_multiprocessing` is dropped as well. It only applies to
# generator/Sequence input, False was already its default, and Keras 3's
# fit() no longer accepts it.
model.fit(
    x=train_dataset,
    epochs=1,
    validation_data=valid_dataset,
)

Can you please point out where I’m going wrong? Is it in the `prepare_sample` function returning `x` as a dict, or somewhere in the model code? I’m new to TensorFlow and quite confused.

Any help appreciated!