Keras model and tf.Dataset with single input type and multiple outputs


I’ve built a VGG-16-style model, but with three output classifier heads, each of a different size. I’m using the object-oriented Keras API (inheriting from tf.keras.Model) and have a tf.data.Dataset loader that supplies the data, but I can’t for the life of me figure out how to separate the input from the outputs. None of the examples I’ve seen seem to handle my particular use case.

Here is how I build my dataset (image_ds). Note the inner map function, which loads the image and also builds three one-hot (categorical) label vectors, one per output:

```python
paths_ds = ...
paths_ds = paths_ds.shuffle(buffer_size = len(image_paths), reshuffle_each_iteration = True)  # reshuffle each epoch
image_ds = paths_ds.map(
  lambda path: (
    self._load_image(path, crop, image_shape, augment),
    self._get_onehot_label(path, tf_hair_length_class_index_by_image, self.num_hair_length_classes),
    self._get_onehot_label(path, tf_hair_volume_class_index_by_image, self.num_hair_volume_classes),
    self._get_onehot_label(path, tf_hair_part_class_index_by_image, self.num_hair_part_classes),
  ),
  num_parallel_calls = ...,
)
image_ds = image_ds.batch(batch_size)
image_ds = image_ds.prefetch(...)
```

Now, this is all batched as expected. For example, with a batch size of 16, each iteration yields 4 tensors with shapes (16, 224, 224, 3), (16, 4), (16, 4), and (16, 5), exactly as desired.
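To make that structure concrete without the real images, here is a small self-contained stand-in for the pipeline above, assuming TF 2.4 or later (for `tf.data.AUTOTUNE`). `fake_load_image` and `fake_onehot` are hypothetical placeholders for `self._load_image` and `self._get_onehot_label`, the file names are dummies, and AUTOTUNE is an assumed choice for the parallelism and prefetch settings.

```python
import tensorflow as tf

NUM_LENGTH, NUM_VOLUME, NUM_PART = 4, 4, 5  # class counts matching the shapes above
BATCH_SIZE = 16

def fake_load_image(path):
    # Hypothetical stand-in for self._load_image: a constant 224x224x3 image.
    return tf.zeros((224, 224, 3), dtype=tf.float32)

def fake_onehot(path, num_classes):
    # Hypothetical stand-in for self._get_onehot_label: hash the path string
    # into a class index, then one-hot encode it.
    index = tf.strings.to_hash_bucket_fast(path, num_classes)
    return tf.one_hot(index, num_classes)

image_paths = [f"img_{i}.jpg" for i in range(40)]  # dummy file names

paths_ds = tf.data.Dataset.from_tensor_slices(image_paths)
paths_ds = paths_ds.shuffle(buffer_size=len(image_paths), reshuffle_each_iteration=True)
image_ds = paths_ds.map(
    lambda path: (
        fake_load_image(path),
        fake_onehot(path, NUM_LENGTH),
        fake_onehot(path, NUM_VOLUME),
        fake_onehot(path, NUM_PART),
    ),
    num_parallel_calls=tf.data.AUTOTUNE,
)
image_ds = image_ds.batch(BATCH_SIZE)
image_ds = image_ds.prefetch(tf.data.AUTOTUNE)

first_batch = next(iter(image_ds))
print([tuple(t.shape) for t in first_batch])
```

Each element is a flat 4-tuple `(image, y_length, y_volume, y_part)`, and that flat layout is what `fit()` has to interpret.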

But now, how do I pass these into Keras?

Inheriting from tf.keras.Model, I implement a call() method:

```python
def call(self, inputs, training = False):
  x = inputs[0]
  y1 = inputs[1]
  y2 = inputs[2]
  y3 = inputs[3]
```

Note that I am computing losses manually in call() and adding them to the model (so I don’t need to specify any losses when compiling the model). I feel this is easier to follow than trying to specify the loss function outside, although I am open to doing it that way too.
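If the losses do move outside, a minimal sketch of that arrangement (assuming TF 2.x with its bundled Keras) could look like this. `HairNet` and its one-convolution backbone are hypothetical stand-ins for the VGG-16 layers. The key differences from the code in this post: `call()` receives only the image batch and returns one prediction per head, and `compile()` receives one loss per head, matched by position. Calling the model once on an image batch creates its weights, so only the image shape is involved.

```python
import tensorflow as tf

class HairNet(tf.keras.Model):
    """Hypothetical, heavily shrunk stand-in for the three-headed VGG-16-style model."""

    def __init__(self, num_length=4, num_volume=4, num_part=5):
        super().__init__()
        # Tiny backbone in place of the VGG-16 conv blocks.
        self.conv = tf.keras.layers.Conv2D(8, 3, activation="relu")
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        # One softmax classifier head per task.
        self.length_head = tf.keras.layers.Dense(num_length, activation="softmax")
        self.volume_head = tf.keras.layers.Dense(num_volume, activation="softmax")
        self.part_head = tf.keras.layers.Dense(num_part, activation="softmax")

    def call(self, inputs, training=False):
        # `inputs` is only the image batch; the labels never reach call().
        features = self.pool(self.conv(inputs))
        return (self.length_head(features),
                self.volume_head(features),
                self.part_head(features))

model = HairNet()
model.compile(
    optimizer="adam",
    # One loss per output, matched by position to the tuple returned by call().
    loss=["categorical_crossentropy"] * 3,
)
# Running one image batch through the model creates its weights.
outputs = model(tf.zeros((2, 224, 224, 3)))
print([tuple(o.shape) for o in outputs])
```

The heads here end in softmax, so the default `from_logits=False` of categorical cross-entropy fits; heads that emit raw logits would instead pair with `tf.keras.losses.CategoricalCrossentropy(from_logits=True)`.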

And to build the model:
```python
input_shape = [
  (None, 224, 224, 3),  # input_image: (batch_size, height_pixels, width_pixels, 3)
  (None, training_data.num_hair_length_classes),  # (batch_size, num_hair_length_classes)
  (None, training_data.num_hair_volume_classes),  # (batch_size, num_hair_volume_classes)
  (None, training_data.num_hair_part_classes),    # (batch_size, num_hair_part_classes)
]
```

Then I want to fit:

```python
model.fit(callbacks = callbacks_list,
          x = training_data.image_ds,
          validation_data = validation_data.image_ds,
          epochs = options.epochs)
```

And it explodes.

I’ve tried all kinds of variations, including pre-fetching everything out of the dataset into 4 separate arrays (x, y1, y2, y3), but I can’t find a way to make this work! Surely this is not an uncommon case?

I see lots of examples for the sequential and functional APIs but not this one.

Any guidance would be appreciated.

Thank you,


Would you mind posting the full error message you see when “it explodes”?


```
ValueError: in user code:

    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\keras\engine\ train_function  *
        return step_function(self, iterator)
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\keras\engine\ step_function  **
        outputs =, args=(data,))
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\distribute\ run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\distribute\ call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\distribute\ _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\keras\engine\ run_step  **
        outputs = model.train_step(data)
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\keras\engine\ train_step
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    C:\Users\Bart\anaconda3\envs\avatarize\lib\site-packages\tensorflow\python\keras\engine\ unpack_x_y_sample_weight
        raise ValueError(error_msg)

    ValueError: Data is expected to be in format `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`, found: (<tf.Tensor 'IteratorGetNext:0' shape=(None, 224, 224, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, 4) dtype=float32>, <tf.Tensor 'IteratorGetNext:2' shape=(None, 4) dtype=float32>, <tf.Tensor 'IteratorGetNext:3' shape=(None, 5) dtype=float32>)
```
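The last line of the traceback lists the layouts that `fit()` accepts per element: `x`, `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. A flat 4-tuple matches none of them. One common remedy, sketched here on the assumption that the last three tensors are all targets, is to regroup each element into `(image, (y1, y2, y3))` with one more `map`. `to_x_y` is a hypothetical helper name, and `flat_ds` is a dummy stand-in for `image_ds`.

```python
import tensorflow as tf

# Dummy stand-in for image_ds: one already-batched flat 4-tuple (image, y1, y2, y3).
flat_ds = tf.data.Dataset.from_tensors((
    tf.zeros((16, 224, 224, 3)),
    tf.zeros((16, 4)),
    tf.zeros((16, 4)),
    tf.zeros((16, 5)),
))

def to_x_y(image, y_length, y_volume, y_part):
    # Dataset.map unpacks the tuple into positional arguments. Return (x, y),
    # where y is a tuple holding one target per model output.
    return image, (y_length, y_volume, y_part)

xy_ds = flat_ds.map(to_x_y)
x_spec, y_spec = xy_ds.element_spec
print(x_spec.shape, [s.shape for s in y_spec])
```

Applied to the real pipeline, this would be something like `training_data.image_ds.map(to_x_y)` (and likewise for the validation set), paired with a `call()` that returns the three predictions in the same order as the targets.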

So some more context:

I might be misunderstanding the intended use case for inheriting from tf.keras.Model, but I like it because it seems to afford the most flexibility. The call() method only provides an “inputs” argument, however, so that’s the first point of confusion: how are outputs intended to be handled?

In general, I’m wondering how one would specify multiple inputs and multiple outputs simultaneously. The fit() method is great because it allows a validation dataset to be passed in and should be the least fuss.
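On the general question, multiple inputs and multiple outputs are usually expressed as nested tuples, `((input_a, input_b), (y1, y2, y3))`. `fit()` splits each element into `x` and `y`, hands `x` (the inner tuple) to `call()`, and matches `y` against the model's outputs. The sketch below assumes TF 2.x with its bundled Keras and invents a second input (a hypothetical 3-value metadata vector); the data is random, images are shrunk to 32x32 for speed, and the same dataset doubles as `validation_data` only to show that `val_` metrics appear.

```python
import tensorflow as tf

class TwoInThreeOut(tf.keras.Model):
    """Hypothetical model: an image plus a metadata vector in, three heads out."""

    def __init__(self):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(4, 3, activation="relu")
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.concat = tf.keras.layers.Concatenate()
        self.heads = [tf.keras.layers.Dense(n, activation="softmax") for n in (4, 4, 5)]

    def call(self, inputs, training=False):
        image, meta = inputs  # x arrives as the same tuple the dataset yields
        features = self.concat([self.pool(self.conv(image)), meta])
        return tuple(head(features) for head in self.heads)

n = 8
x = (tf.random.uniform((n, 32, 32, 3)), tf.random.uniform((n, 3)))
y = tuple(tf.one_hot(tf.zeros([n], tf.int32), k) for k in (4, 4, 5))
# Each element is ((image, meta), (y1, y2, y3)).
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

model = TwoInThreeOut()
model.compile(optimizer="adam", loss=["categorical_crossentropy"] * 3)
history = model.fit(ds, validation_data=ds, epochs=1, verbose=0)
print(sorted(history.history))
```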

Instead, here is how I “solved” my problem. I wrote my own training loop using train_on_batch:

```python
for epoch in range(1, 1 + options.epochs):
  print("Epoch %d/%d" % (epoch, options.epochs))
  for inputs in training_data.dataset:
    losses = model.train_on_batch(x = inputs, return_dict = True)
  evaluate(model = model, eval_data = validation_data, print_accuracy = True)
```

Note that I have to call a separate evaluate() function. This solution is flexible but less than ideal: I can no longer use Keras callbacks (e.g., ModelCheckpoint and ReduceLROnPlateau), and I can’t track metrics like val_acc.
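Once each element has the `(x, (y1, y2, y3))` layout, the manual loop should no longer be needed: `fit()` accepts `validation_data`, reports `val_` metrics, and drives callbacks such as `ModelCheckpoint` and `ReduceLROnPlateau`. This sketch assumes TF 2.x with its bundled Keras; `TinyHairNet` and `dummy_ds` are hypothetical stand-ins for the real model and data, the checkpoint is written to a temporary directory, and the callback settings are illustrative rather than tuned.

```python
import os
import tempfile
import tensorflow as tf

class TinyHairNet(tf.keras.Model):
    """Hypothetical miniature stand-in for the three-headed model."""

    def __init__(self):
        super().__init__()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.heads = [tf.keras.layers.Dense(n, activation="softmax") for n in (4, 4, 5)]

    def call(self, inputs, training=False):
        features = self.pool(inputs)
        return tuple(head(features) for head in self.heads)

def dummy_ds(n=8):
    # Random images and constant one-hot labels, already in (x, (y1, y2, y3)) form.
    x = tf.random.uniform((n, 32, 32, 3))
    y = tuple(tf.one_hot(tf.zeros([n], tf.int32), k) for k in (4, 4, 5))
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

model = TinyHairNet()
model.compile(optimizer="adam", loss=["categorical_crossentropy"] * 3)

checkpoint_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
callbacks_list = [
    # Keep the weights from the epoch with the lowest validation loss.
    tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor="val_loss",
                                       save_best_only=True, save_weights_only=True),
    # Halve the learning rate if val_loss stops improving for 2 epochs.
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
]

history = model.fit(dummy_ds(), validation_data=dummy_ds(), epochs=3,
                    callbacks=callbacks_list, verbose=0)
print(sorted(history.history))
```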

I compute losses manually inside my model. Here is the complete call function:

```python
def call(self, inputs, training = False):
  input_image = inputs[0]         # (batch_size, height_pixels, width_pixels, 3)

  # Backbone
  y = self._block1_conv1(input_image)
  y = self._block1_conv2(y)
  y = self._block1_maxpool(y)

  y = self._block2_conv1(y)
  y = self._block2_conv2(y)
  y = self._block2_maxpool(y)

  y = self._block3_conv1(y)
  y = self._block3_conv2(y)
  y = self._block3_conv3(y)
  y = self._block3_maxpool(y)

  y = self._block4_conv1(y)
  y = self._block4_conv2(y)
  y = self._block4_conv3(y)
  y = self._block4_maxpool(y)

  y = self._block5_conv1(y)
  y = self._block5_conv2(y)
  y = self._block5_conv3(y)
  y = self._block5_maxpool(y)

  y = self._flatten(y)

  # Hair length head
  yhl = self._hair_length_fc1(y)
  yhl = self._hair_length_do1(yhl)
  yhl = self._hair_length_fc2(yhl)
  yhl = self._hair_length_do2(yhl)
  yhl = self._hair_length_predictions(yhl)

  # Hair volume head
  yhv = self._hair_volume_fc1(y)
  yhv = self._hair_volume_do1(yhv)
  yhv = self._hair_volume_fc2(yhv)
  yhv = self._hair_volume_do2(yhv)
  yhv = self._hair_volume_predictions(yhv)

  # Hair part head
  yhp = self._hair_part_fc1(y)
  yhp = self._hair_part_do1(yhp)
  yhp = self._hair_part_fc2(yhp)
  yhp = self._hair_part_do2(yhp)
  yhp = self._hair_part_predictions(yhp)

  # Losses
  if training:
    y_true_hair_length = inputs[1]  # (batch_size, hair_length_classes)
    y_true_hair_volume = inputs[2]  # (batch_size, hair_volume_classes)
    y_true_hair_part = inputs[3]    # (batch_size, hair_part_classes)
    hair_length_loss = self._cross_entropy_loss(y_predicted = yhl, y_true = y_true_hair_length)
    hair_volume_loss = self._cross_entropy_loss(y_predicted = yhv, y_true = y_true_hair_volume)
    hair_part_loss = self._cross_entropy_loss(y_predicted = yhp, y_true = y_true_hair_part)
    self.add_metric(hair_length_loss, name = "hair_length_loss")
    self.add_metric(hair_volume_loss, name = "hair_volume_loss")
    self.add_metric(hair_part_loss, name = "hair_part_loss")
    self.add_metric(hair_length_loss + hair_volume_loss + hair_part_loss, name = "loss")
  else:
    # Losses cannot be computed during inference and should be ignored
    hair_length_loss = float("inf")
    hair_volume_loss = float("inf")
    hair_part_loss = float("inf")

  # Return outputs
  if training:
    return [...]
  else:
    return [...]
```

You can see that “inputs” now contains all 4 of my tensors (one actual input and three ground-truth output values). I unpack them manually and also compute my loss there when in training mode. But I don’t think this is the “correct” way to do it, and I lose a lot of functionality by doing so (e.g., I can’t use the fit() method with its automatic handling of callbacks).
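If keeping the flat 4-tuple batches and a hand-written loss is preferable, one option (TF 2.2 and later) is to keep `call()` as pure inference and move the loss into overridden `train_step()` and `test_step()` methods, which receive each batch exactly as the dataset yields it. That keeps `fit()`, callbacks, and `val_loss` available. This is a sketch, not the code above: `HairNetCustomStep`, its pooling-only backbone, and the stock `CategoricalCrossentropy` standing in for `_cross_entropy_loss` are all assumptions.

```python
import tensorflow as tf

class HairNetCustomStep(tf.keras.Model):
    """Hypothetical sketch: flat (image, y1, y2, y3) batches, loss computed in train_step."""

    def __init__(self):
        super().__init__()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.heads = [tf.keras.layers.Dense(n, activation="softmax") for n in (4, 4, 5)]
        self.cross_entropy = tf.keras.losses.CategoricalCrossentropy()
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch.
        return [self.loss_tracker]

    def call(self, inputs, training=False):
        # Pure inference: image batch in, three predictions out.
        features = self.pool(inputs)
        return [head(features) for head in self.heads]

    def _total_loss(self, data, training):
        image, *targets = data  # the flat tuple, straight from the dataset
        predictions = self(image, training=training)
        return tf.add_n([self.cross_entropy(t, p) for t, p in zip(targets, predictions)])

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self._total_loss(data, training=True)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Same loss without a weight update; fit() reports it as "val_loss".
        self.loss_tracker.update_state(self._total_loss(data, training=False))
        return {"loss": self.loss_tracker.result()}

def flat_ds(n=8):
    # Same flat layout as image_ds: (image, y_length, y_volume, y_part).
    images = tf.random.uniform((n, 32, 32, 3))
    labels = [tf.one_hot(tf.zeros([n], tf.int32), k) for k in (4, 4, 5)]
    return tf.data.Dataset.from_tensor_slices((images, *labels)).batch(4)

model = HairNetCustomStep()
model(tf.zeros((1, 32, 32, 3)))  # create the weights with one dummy image batch
model.compile(optimizer="adam")  # no loss here: train_step computes it
history = model.fit(flat_ds(), validation_data=flat_ds(), epochs=2, verbose=0)
print(sorted(history.history))
```

With this split, the targets never pass through `call()`, so `training` only controls layer behavior such as dropout, and inference needs no dummy labels.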