Transfer Learning - Could not find matching function from the SavedModel

Hello everyone,

I’m trying to use transfer learning on my own dataset structured as:

dataset
    healthy
    unhealthy
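With this layout, image_dataset_from_directory treats each subdirectory as one class. When class_names is not passed, it orders the subdirectory names alphanumerically, so healthy should become label 0 and unhealthy label 1. The plain-Python sketch below mimics that inference. infer_class_names is a hypothetical helper written for illustration, not the Keras implementation, and the temporary directory stands in for the real dataset folder.

```python
import os
import tempfile


def infer_class_names(root):
    """Hypothetical helper: list class subdirectories the way
    image_dataset_from_directory orders them (sorted alphanumerically)."""
    return sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )


# Build a throwaway copy of the healthy/unhealthy layout (no images needed).
with tempfile.TemporaryDirectory() as root:
    for name in ("unhealthy", "healthy"):
        os.makedirs(os.path.join(root, name))
    names = infer_class_names(root)
    # The label of each class is its index in the sorted list.
    print({name: i for i, name in enumerate(names)})
```

The labels are therefore plain integers (0 and 1), which matters later when choosing the loss function.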

First, I loaded these images off disk using image_dataset_from_directory.
Then, as good practice, I split the data into training (80%), validation (10%), and test (10%) sets while developing the model:

import datetime

import tensorflow as tf
import tensorflow_hub as hub

batch_size = 32
img_height = 244
img_width = 244

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  image_path,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

class_names = train_ds.class_names

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  image_path,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

# Split the 20% validation subset in half: ~10% test, ~10% validation
val_batches = tf.data.experimental.cardinality(val_ds)
test_dataset = val_ds.take(val_batches // 2)
val_ds = val_ds.skip(val_batches // 2)
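To see how the divisor passed to take() and skip() divides the data, the batch arithmetic can be sketched in plain Python. split_sizes is a hypothetical helper; it assumes image_dataset_from_directory holds out int(validation_split * num_images) images for the validation subset and that take() grabs whole batches from the front. The count of 1000 images is made up for illustration.

```python
import math


def split_sizes(num_images, batch_size, validation_split, test_divisor):
    """Hypothetical helper: estimate (train, val, test) image counts when a
    test set is carved out with val_ds.take(val_batches // test_divisor)."""
    # Assumption: the validation subset gets int(validation_split * N) images.
    num_val = int(validation_split * num_images)
    num_train = num_images - num_val
    val_batches = math.ceil(num_val / batch_size)
    test_batches = val_batches // test_divisor
    # Only the last validation batch can be partial, and for test_divisor >= 2
    # it is never among the batches taken; min() guards the divisor == 1 case.
    num_test = min(test_batches * batch_size, num_val)
    return num_train, num_val - num_test, num_test


for divisor in (5, 2):
    train, val, test = split_sizes(1000, 32, 0.2, divisor)
    print(f"divisor={divisor}: train={train}, val={val}, test={test}")
```

With these made-up numbers, a divisor of 5 leaves only 32 test images (about 3%), while a divisor of 2 gives 96 test and 104 validation images, close to the intended 10% / 10%.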

Then I train a simple transfer-learning model on the datasets prepared above:

model = tf.keras.Sequential([
    hub.KerasLayer('https://tfhub.dev/google/cropnet/feature_vector/cassava_disease_V1/1'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(2)
])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy']
)

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

model.fit(train_ds, epochs=10, validation_data=val_ds, callbacks=[tensorboard_callback])

I get the following error:

ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (3 total):
* Tensor("image:0", shape=(None, 244, 244, 3), dtype=float32)
* False
* 0.99
Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (3 total):
    * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='image')
    * False
    * 0.99
  Keyword arguments: {}

Option 2:
  Positional arguments (3 total):
    * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='image')
    * False
    * TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
  Keyword arguments: {}

Option 3:
  Positional arguments (3 total):
    * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='image')
    * True
    * 0.99
  Keyword arguments: {}

Option 4:
  Positional arguments (3 total):
    * TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='image')
    * True
    * TensorSpec(shape=(), dtype=tf.float32, name='batch_norm_momentum')
  Keyword arguments: {}

@Yuqi_Li might have an insight on this


Please use tf.keras.losses.CategoricalCrossentropy instead of tf.keras.losses.SparseCategoricalCrossentropy.


I changed tf.keras.losses.SparseCategoricalCrossentropy to tf.keras.losses.CategoricalCrossentropy, but I get the same error.
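As far as I can tell, the loss is not what this error is about: it is raised when the hub layer calls the restored SavedModel function, before any loss is computed. With the default label_mode="int", image_dataset_from_directory yields integer labels, which is what SparseCategoricalCrossentropy expects; CategoricalCrossentropy expects one-hot labels (label_mode="categorical"). The sketch below uses made-up logits to show that the two matching pairings compute the same loss.

```python
import tensorflow as tf

# Integer labels, as yielded with the default label_mode="int".
int_labels = tf.constant([0, 1, 1])
# The same labels one-hot encoded, as yielded with label_mode="categorical".
one_hot_labels = tf.one_hot(int_labels, depth=2)
# Made-up raw outputs of a final Dense(2) layer with no activation.
logits = tf.constant([[2.0, -1.0], [0.5, 0.3], [-1.0, 3.0]])

sparse_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
    int_labels, logits
)
dense_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(
    one_hot_labels, logits
)

# Both pairings give the same value; only the label format differs.
print(float(sparse_loss), float(dense_loss))
```

Switching only the loss, without also switching label_mode, would mismatch the label format rather than fix the input shape.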


The cropnet/feature_vector/cassava_disease_V1 model expects a float input tensor of shape [batch size, 224, 224, 3]. Take a look at this link.

Your images are loaded at 244 x 244 instead, so change the image size to:
img_height = 224
img_width = 224
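The traceback shows the mismatch directly: the call received shape (None, 244, 244, 3), while every signature saved in the module accepts only (None, 224, 224, 3). A hypothetical check_image_shape helper (a sketch, not part of TensorFlow) can catch this from the dataset's element_spec before fit is called. The random 244 x 244 dataset below stands in for train_ds.

```python
import tensorflow as tf


def check_image_shape(ds, expected_hw=(224, 224), channels=3):
    """Hypothetical helper: raise if an (images, labels) dataset yields images
    whose per-example shape differs from what the model signature expects."""
    image_spec = ds.element_spec[0]
    got = tuple(image_spec.shape[1:])  # drop the batch dimension
    want = tuple(expected_hw) + (channels,)
    if got != want:
        raise ValueError(
            f"dataset yields images of shape {got}, model expects {want}"
        )
    return got


# Stand-in for the original train_ds: 4 random 244x244 RGB images, batched by 2.
images = tf.random.uniform((4, 244, 244, 3))
labels = tf.constant([0, 1, 0, 1])
fake_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

try:
    check_image_shape(fake_ds)
except ValueError as err:
    print(err)
```

After switching to 224 x 224, the same check on train_ds should pass and return (224, 224, 3).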

Complete working sample code

import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
print("TF version:", tf.__version__)
print("Hub version:", hub.__version__) 

IMAGE_SIZE = (224, 224)
batch_size = 32
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

class_names = train_ds.class_names

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

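val_batches = tf.data.experimental.cardinality(val_ds)
# Hold out half of the validation batches as a test set and drop them from val_ds
test_dataset = val_ds.take(val_batches // 2)
val_ds = val_ds.skip(val_batches // 2)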

model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    #tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer('https://tfhub.dev/google/cropnet/feature_vector/cassava_disease_V1/1'),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(len(class_names),
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy']
)
model.fit(train_ds, epochs=10, validation_data=val_ds)
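An alternative to reloading the data at 224 x 224 is to resize inside the model, so images of any size reach the hub layer with the expected shape. The sketch below shows only that preprocessing front end. It also rescales pixels from [0, 255] to [0, 1], on the assumption that this module follows TF Hub's usual [0, 1] image-input convention; that assumption is worth confirming on the model page. The Resizing and Rescaling layers are standard Keras layers; the fake batch is made up.

```python
import tensorflow as tf

# Hypothetical preprocessing front end: resize to the 224x224 the CropNet
# module expects, then scale [0, 255] pixel values to [0, 1] (assumed
# input range; check the model's documentation).
preprocess = tf.keras.Sequential([
    tf.keras.layers.Resizing(224, 224),
    tf.keras.layers.Rescaling(1.0 / 255),
])

# A fake batch shaped like the original 244x244 images, values in [0, 255).
batch = tf.random.uniform((2, 244, 244, 3), maxval=255.0)
out = preprocess(batch)
print(out.shape)
```

In the full model these two layers would sit before hub.KerasLayer(...). Keeping the resize inside the model means callers can pass any image size at inference time; loading directly at 224 x 224, as in the code above, avoids the extra resize step.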

This solved the issue!
Thank you