I’ve built a VGG-16-style model with three classifier output heads, each of a different size. I’m using the object-oriented Keras API (subclassing tf.keras.Model) and have a tf.data.Dataset loader that supplies the data, but I can’t for the life of me figure out how to separate the inputs from the outputs. None of the examples I’ve found seem to handle my particular use case.
Here is how I build my dataset (image_ds). Note the inner map function, which loads the image along with three one-hot (categorical) label vectors, one per output:
```python
paths_ds = tf.data.Dataset.from_tensor_slices(image_paths)
paths_ds = paths_ds.shuffle(buffer_size = len(image_paths), reshuffle_each_iteration = True)  # reshuffle each epoch
image_ds = paths_ds.map(
    lambda path: (
        self._load_image(path, crop, image_shape, augment),
        self._get_onehot_label(path, tf_hair_length_class_index_by_image, self.num_hair_length_classes),
        self._get_onehot_label(path, tf_hair_volume_class_index_by_image, self.num_hair_volume_classes),
        self._get_onehot_label(path, tf_hair_part_class_index_by_image, self.num_hair_part_classes),
    ),
    num_parallel_calls = tf.data.experimental.AUTOTUNE
)
image_ds = image_ds.batch(batch_size)
image_ds = image_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
```
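`Model.fit()` reads each dataset element as `(inputs, targets)` or `(inputs, targets, sample_weights)`, so a flat 4-tuple cannot be split into one input and three targets. Below is a minimal sketch of regrouping the elements as `(image, (y1, y2, y3))` with one extra `map`. It assumes dummy tensors in place of `self._load_image` and `self._get_onehot_label`; the `fake_example` function and the file names are hypothetical.

```python
import tensorflow as tf

BATCH_SIZE = 16
NUM_CLASSES = (4, 4, 5)  # hair length, volume, part (sizes from the shapes quoted in the question)

def fake_example(path):
    # Hypothetical stand-in for self._load_image / self._get_onehot_label:
    # a blank 224x224 RGB image plus three one-hot label vectors.
    image = tf.zeros((224, 224, 3), tf.float32)
    labels = tuple(tf.one_hot(0, n) for n in NUM_CLASSES)
    return (image,) + labels  # flat 4-tuple, like the original map()

paths = [f"img_{i}.jpg" for i in range(32)]  # hypothetical file names
flat_ds = tf.data.Dataset.from_tensor_slices(paths).map(
    fake_example, num_parallel_calls=tf.data.AUTOTUNE)

# Regroup (image, y1, y2, y3) into (image, (y1, y2, y3)), i.e. (inputs, targets).
image_ds = (flat_ds
            .map(lambda img, y1, y2, y3: (img, (y1, y2, y3)),
                 num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))

x, (y1, y2, y3) = next(iter(image_ds))
print(x.shape, y1.shape, y2.shape, y3.shape)
```

The extra `map` is only there to keep the original loader recognizable. The same effect should follow from having the original lambda return the image followed by a nested tuple of the three `_get_onehot_label(...)` results.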
Now, this is all batched as expected. For example, with a batch size of 16, each iteration gives me four tensors of shapes (16,224,224,3), (16,4), (16,4) and (16,5), exactly as desired.
But now, how do I pass these into Keras?
Inheriting from tf.keras.Model, I implement a call() method:
```python
def call(self, inputs, training = False):
    x = inputs
    y1 = inputs
    y2 = inputs
    y3 = inputs
    #...
```
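If the labels really should reach `call()` so losses can be added there, one option is to wrap all four tensors in a 1-tuple. Keras reads a 1-tuple element as `x` with no separate targets, and `call()` can then unpack `inputs` directly. Below is a minimal sketch under that assumption. `ThreeHeadModel` and `make_ds` are hypothetical, and a tiny convolutional trunk on random 64x64 images stands in for VGG-16.

```python
import tensorflow as tf

NUM_CLASSES = (4, 4, 5)  # hair length, volume, part

class ThreeHeadModel(tf.keras.Model):
    """Hypothetical sketch: labels arrive inside `inputs`; losses are added in call()."""

    def __init__(self):
        super().__init__()
        # Tiny stand-in for the VGG-16 convolutional trunk.
        self.conv = tf.keras.layers.Conv2D(8, 3, strides=4, activation="relu")
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.heads = [tf.keras.layers.Dense(n, activation="softmax") for n in NUM_CLASSES]
        self.cce = tf.keras.losses.CategoricalCrossentropy()

    def call(self, inputs, training=False):
        x, y1, y2, y3 = inputs  # image batch plus three one-hot label batches
        features = self.pool(self.conv(x))
        preds = tuple(head(features) for head in self.heads)
        for y_true, y_pred in zip((y1, y2, y3), preds):
            self.add_loss(self.cce(y_true, y_pred))  # one loss term per head
        return preds

def make_ds(n=32, batch_size=16, size=64):
    images = tf.random.uniform((n, size, size, 3))
    labels = [tf.one_hot(tf.random.uniform((n,), maxval=k, dtype=tf.int32), k)
              for k in NUM_CLASSES]
    ds = tf.data.Dataset.from_tensor_slices((images, *labels)).batch(batch_size)
    # The trailing comma makes a 1-tuple, which Keras treats as x only (no y).
    return ds.map(lambda img, a, b, c: ((img, a, b, c),))

model = ThreeHeadModel()
model.compile(optimizer="adam")  # no loss here: add_loss() supplies it
history = model.fit(make_ds(), epochs=1, verbose=0)
```

The catch with this design is that prediction also needs label tensors, because `call()` always unpacks four inputs. That is one reason targets are usually kept out of `x`.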
Note that I am computing the losses manually in call() and adding them to the model, so I don’t need to specify any losses when compiling. I find this easier to follow than specifying the loss functions outside the model, although I am open to doing it that way too.
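Here is a sketch of the alternative, where the loss is specified outside `call()`. It uses the same assumptions as before: a tiny hypothetical trunk (`ThreeHeadClassifier`) and random data. `call()` takes only the image and returns a tuple of three predictions. The dataset yields `(image, (y1, y2, y3))`. `compile()` receives a single categorical cross-entropy loss, which Keras applies to each output and sums.

```python
import tensorflow as tf

NUM_CLASSES = (4, 4, 5)  # hair length, volume, part

class ThreeHeadClassifier(tf.keras.Model):
    """Hypothetical sketch: image in, three softmax heads out; losses set in compile()."""

    def __init__(self):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(8, 3, strides=4, activation="relu")
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.heads = [tf.keras.layers.Dense(n, activation="softmax") for n in NUM_CLASSES]

    def call(self, inputs, training=False):
        features = self.pool(self.conv(inputs))  # inputs is just the image batch
        return tuple(head(features) for head in self.heads)

def make_ds(n=32, batch_size=16, size=64):
    images = tf.random.uniform((n, size, size, 3))
    labels = tuple(tf.one_hot(tf.random.uniform((n,), maxval=k, dtype=tf.int32), k)
                   for k in NUM_CLASSES)
    # Each element is (inputs, targets) = (image, (y1, y2, y3)).
    return tf.data.Dataset.from_tensor_slices((images, labels)).batch(batch_size)

model = ThreeHeadClassifier()
# A single loss is broadcast to every output; the per-output losses are summed.
model.compile(optimizer="adam", loss="categorical_crossentropy")
history = model.fit(make_ds(), validation_data=make_ds(), epochs=1, verbose=0)
```

Per the `compile()` documentation, you can pass a list of losses (one per output) and `loss_weights` instead if the heads need different losses or weighting.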
And to build the model:
```python
model.build(
    input_shape = [
        (None, 224, 224, 3),                              # input_image: (batch_size, height_pixels, width_pixels, 3)
        (None, training_data.num_hair_length_classes),   # (batch_size, num_hair_length_classes)
        (None, training_data.num_hair_volume_classes),   # (batch_size, num_hair_volume_classes)
        (None, training_data.num_hair_part_classes),     # (batch_size, num_hair_part_classes)
    ]
)
```
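The `input_shape` list above includes the label shapes only because the labels are being fed in as model inputs. Once they are targets instead, only the image shape matters for creating the weights. A minimal, version-agnostic way to do that is to call the model once on a dummy image batch, which builds every layer. The sketch below does this with a tiny hypothetical image-only model (`TinyThreeHead`). In tf.keras 2.x, `model.build((None, 224, 224, 3))` should achieve the same for an image-only `call()`.

```python
import tensorflow as tf

class TinyThreeHead(tf.keras.Model):
    """Hypothetical image-only model with three softmax heads (sizes 4, 4, 5)."""

    def __init__(self, num_classes=(4, 4, 5)):
        super().__init__()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.heads = [tf.keras.layers.Dense(n, activation="softmax") for n in num_classes]

    def call(self, inputs, training=False):
        features = self.pool(inputs)
        return tuple(head(features) for head in self.heads)

model = TinyThreeHead()
# One forward pass on a dummy (batch, height, width, channels) batch creates all weights.
outputs = model(tf.zeros((1, 224, 224, 3)))
print([tuple(o.shape) for o in outputs])
```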
Then I want to fit:
```python
model.fit(
    callbacks = callbacks_list,
    x = training_data.image_ds,
    validation_data = validation_data.image_ds,
    epochs = options.epochs
)
```
And it explodes.
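The error message isn't quoted, so this is an assumption, but one likely culprit is the element structure. Keras unpacks each dataset element as `x`, `(x,)`, `(x, y)` or `(x, y, sample_weight)`. The public helper `tf.keras.utils.unpack_x_y_sample_weight` applies that rule, so it can check a batch before `fit()` is called. Below is a sketch with dummy tensors shaped like the question's batches; `dummy_batch` is hypothetical.

```python
import tensorflow as tf

def dummy_batch():
    # Hypothetical batch shaped like the question's: image plus three one-hot label batches.
    image = tf.zeros((16, 224, 224, 3))
    return image, tf.zeros((16, 4)), tf.zeros((16, 4)), tf.zeros((16, 5))

flat = dummy_batch()                    # (image, y1, y2, y3): four top-level elements
grouped = (flat[0], tuple(flat[1:]))    # (image, (y1, y2, y3)): inputs, targets

try:
    tf.keras.utils.unpack_x_y_sample_weight(flat)
    flat_ok = True
except ValueError as err:
    flat_ok = False
    print("flat 4-tuple rejected:", str(err)[:60])

x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(grouped)
print(x.shape, [t.shape for t in y], sample_weight)
```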
I’ve tried all kinds of variations, including pre-fetching everything out of the dataset into four separate arrays (x, y1, y2, y3), but I can’t find a way to make this work! Surely this is not an uncommon case?
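With plain arrays, the same `(inputs, targets)` split applies: `x` is the image array and `y` is a tuple of the three label arrays, matched by position to the tuple returned by `call()`. Below is a sketch with random NumPy data and a tiny hypothetical image-only model (`TinyThreeHead`).

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
n = 32
images = rng.random((n, 32, 32, 3), dtype=np.float32)
# Three one-hot label arrays with 4, 4 and 5 classes (random, for illustration only).
y_length, y_volume, y_part = (np.eye(k, dtype=np.float32)[rng.integers(0, k, n)]
                              for k in (4, 4, 5))

class TinyThreeHead(tf.keras.Model):
    """Hypothetical image-only model returning a tuple of three softmax outputs."""

    def __init__(self):
        super().__init__()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.heads = [tf.keras.layers.Dense(k, activation="softmax") for k in (4, 4, 5)]

    def call(self, inputs, training=False):
        features = self.pool(inputs)
        return tuple(head(features) for head in self.heads)

model = TinyThreeHead()
model.compile(optimizer="adam", loss="categorical_crossentropy")
# x is the images alone; y is a tuple with one target array per output head.
history = model.fit(x=images, y=(y_length, y_volume, y_part),
                    batch_size=16, epochs=1, verbose=0)
```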
I see lots of examples for the sequential and functional APIs but not this one.
Any guidance would be appreciated.