ValueError: as_list() is not defined on an unknown TensorShape

Hi, I’m trying to use tf.data.Dataset.list_files to load .tiff images and infer their labels from their names.

I use the following code, but stumbled upon a strange issue, as described below:

import os
import datetime as dt
import numpy as np
import pathlib
from pathlib import Path
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import cv2


DATA_PATH = Path('PATH TO DATA')
BATCH_SIZE = 32
INPUT_IMAGE_SHAPE = (128, 128, 1)
CROP_SHAPE = INPUT_IMAGE_SHAPE
CENTRAL_CROP_PROP = .7
BRIGHTNESS_DELTA = 0.1
CONTRAST = (0.4, 0.6)

class ConvModel(keras.Model):
    def __init__(self, input_shape):
        super().__init__()
        self.input_image_shape = input_shape
        self.model = keras.Sequential([
            layers.Input(shape=input_shape),
            layers.Conv2D(32, 3),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPool2D(),
            layers.Conv2D(64, 5),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPool2D(),
            layers.Conv2D(128, 3, kernel_regularizer=keras.regularizers.l2(0.01)),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Flatten(),
            layers.Dense(64, activation='relu', kernel_regularizer=keras.regularizers.l2(0.01)),
            layers.Dropout(0.5),
            layers.Dense(10)
        ])

    def call(self, inputs):
        return self.model(inputs)


def preprocessing_func(image):
    img = tf.image.central_crop(image, CENTRAL_CROP_PROP)
    if img.shape[2] == 3:
        img = tf.image.rgb_to_grayscale(img)
    return img


def augment(image):
    img = tf.image.random_crop(image, size=CROP_SHAPE)  # Slices a shape size portion out of value at a uniformly chosen offset. Requires value.shape >= size.
    img = tf.image.random_brightness(img, max_delta=BRIGHTNESS_DELTA)  # Equivalent to adjust_brightness() using a delta randomly picked in the interval [-max_delta, max_delta)
    img = tf.image.random_contrast(img, lower=CONTRAST[0], upper=CONTRAST[1])  # Equivalent to adjust_contrast() but uses a contrast_factor randomly picked in the interval [lower, upper).
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)

    return img


def load_image(image_file):
    # 1) Decode the path
    image_file = image_file.decode('utf-8')

    # 2) Read the image
    img = cv2.imread(image_file)
    if len(img.shape) < 3:
        img = np.expand_dims(img, axis=-1)
    img = preprocessing_func(image=img)
    img = augment(img)
    img = tf.cast(img, tf.float32)
    img.set_shape(INPUT_IMAGE_SHAPE)
    # 3) Get the label
    label = tf.strings.split(image_file, "\\")[-1]
    label = tf.strings.substr(label, pos=0, len=1)
    label = tf.strings.to_number(label, out_type=tf.float32)
    label = tf.cast(label, tf.float32)
    label.set_shape([])
    return img, label

def _fixup_shape(images, labels):
    images.set_shape(INPUT_IMAGE_SHAPE)
    labels.set_shape([])
    return images, labels

if __name__=='__main__':
    train_ds = tf.data.Dataset.list_files(str(DATA_PATH / '*.tiff'))
    train_ds = train_ds.map(lambda x: tf.numpy_function(load_image, [x], (tf.float32, tf.float32)))
    # train_ds = train_ds.map(_fixup_shape)
    train_ds = train_ds.batch(BATCH_SIZE)
    train_ds = train_ds.shuffle(buffer_size=1000)
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    train_ds = train_ds.repeat()

    model = ConvModel(input_shape=INPUT_IMAGE_SHAPE)
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=3e-4),
        metrics=['accuracy']
    )

    train_log_dir = f'./logs/{dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}/train_data'
    callbacks = [
        keras.callbacks.TensorBoard(
            log_dir=train_log_dir,
            write_images=True
        )
    ]

    model.fit(
        train_ds,
        batch_size=32,
        steps_per_epoch=10,
        epochs=10,
        callbacks=callbacks
    )

When I try to run it, it throws the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-213-b1f3d317135b> in <module>
----> 1 model.fit(
      2     train_ds,
      3     batch_size=32,
      4     steps_per_epoch=10,
      5     epochs=10,

~\anaconda3\lib\site-packages\keras\utils\traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~\anaconda3\lib\site-packages\tensorflow\python\framework\func_graph.py in autograph_handler(*args, **kwargs)
   1127           except Exception as e:  # pylint:disable=broad-except
   1128             if hasattr(e, "ag_error_metadata"):
-> 1129               raise e.ag_error_metadata.to_exception(e)
   1130             else:
   1131               raise

ValueError: in user code:

    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\training.py", line 878, in train_function  *
        return step_function(self, iterator)
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\training.py", line 867, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\training.py", line 860, in run_step  **
        outputs = model.train_step(data)
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\training.py", line 817, in train_step
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\compile_utils.py", line 439, in update_state
        self.build(y_pred, y_true)
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\compile_utils.py", line 359, in build
        self._metrics = tf.__internal__.nest.map_structure_up_to(y_pred, self._get_metric_objects,
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\compile_utils.py", line 485, in _get_metric_objects
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\compile_utils.py", line 485, in <listcomp>
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]
    File "C:\Users\mchls\anaconda3\lib\site-packages\keras\engine\compile_utils.py", line 506, in _get_metric_object
        y_t_rank = len(y_t.shape.as_list())

    ValueError: as_list() is not defined on an unknown TensorShape.

Even though manually running X.shape.as_list() and y.shape.as_list() works, as shown below:

X, y = next(iter(train_ds))
X.shape.as_list(), y.shape.as_list()
[OUT] ([16, 128, 128, 1], [16])

This issue is fixed by manually mapping the following function onto the dataset before batching, i.e. train_ds = train_ds.map(_fixup_shape) followed by .batch(BATCH_SIZE):

def _fixup_shape(images, labels):
    images.set_shape([128, 128, 1])
    labels.set_shape([]) # I have 19 classes
    # weights.set_shape([None])
    return images, labels

if __name__=='__main__':
    train_ds = tf.data.Dataset.list_files(str(DATA_PATH / '*.tiff'))
    train_ds = train_ds.map(lambda x: tf.numpy_function(load_image, [x], (tf.float32, tf.float32)))
    train_ds = train_ds.map(_fixup_shape)
    train_ds = train_ds.batch(BATCH_SIZE)
    train_ds = train_ds.shuffle(buffer_size=1000)
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    train_ds = train_ds.repeat()

As described in this GitHub thread, there is a problem when using tf.data.Dataset.list_files, which can be solved by mapping a fix-up function onto the dataset.

Is this a bug in TF 2.6.1 or is it expected behavior?

Thanks

This is expected behavior.

TensorFlow can normally handle shapes with unknown dimensions. What it really can’t handle is shapes with an unknown number of dimensions (unknown rank).
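
A quick way to see the difference (a small illustration, not part of the original reply):

import tensorflow as tf

# Unknown dimensions, known rank: as_list() works.
known_rank = tf.TensorShape([None, 128, 128, 1])
print(known_rank.as_list())    # [None, 128, 128, 1]

# Unknown rank: the number of dimensions itself is unknown.
unknown_rank = tf.TensorShape(None)
print(unknown_rank.rank)       # None
unknown_rank.as_list()         # ValueError: as_list() is not defined on an unknown TensorShape.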

tf.data.Dataset.from_generator and tf.py_function get their results from Python code, so those could be anything; you need to specify the shapes for TensorFlow. from_generator gives you the option to specify an output_signature. tf.py_function should do the same; ideally someone will fix that, but until someone does, this is how you do it.
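
For example, a minimal from_generator sketch (with a hypothetical generator) where output_signature gives the graph both the dtype and the static shape of each element:

import tensorflow as tf

def gen():
    # Hypothetical generator yielding (image, label) pairs.
    for _ in range(4):
        yield tf.zeros([128, 128, 1]), 0.0

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(128, 128, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    ),
)
print(ds.element_spec)  # shapes are fully known, so metrics can call shape.as_list()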

For your specific case, a small shape-fixing map could help.
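
Something along the lines of the _fixup_shape you already found, re-attaching the static shapes right after the numpy_function map:

def _fixup_shape(images, labels):
    # Re-attach the static shapes that were lost across the numpy_function boundary.
    images.set_shape(INPUT_IMAGE_SHAPE)  # (128, 128, 1)
    labels.set_shape([])                 # scalar label
    return images, labels

train_ds = train_ds.map(_fixup_shape)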

@markdaoust hi, and thank you for your reply and elaboration on this topic, though it’s still not perfectly clear to me. I mean, in load_image

def load_image(image_file):
    # 1) Decode the path
    image_file = image_file.decode('utf-8')

    # 2) Read the image
    img = cv2.imread(image_file)
    if len(img.shape) < 3:
        img = np.expand_dims(img, axis=-1)
    img = preprocessing_func(image=img)
    img = augment(img)
    img = tf.cast(img, tf.float32)
    img.set_shape(INPUT_IMAGE_SHAPE)
    # 3) Get the label
    label = tf.strings.split(image_file, "\\")[-1]
    label = tf.strings.substr(label, pos=0, len=1)
    label = tf.strings.to_number(label, out_type=tf.float32)
    label = tf.cast(label, tf.float32)
    label.set_shape([])
    return img, label

as shown, I do state the shapes of the images and the labels, but it still requires me to run train_ds = train_ds.map(_fixup_shape), which basically does the same thing (or am I wrong?):

def _fixup_shape(images, labels):
    images.set_shape(INPUT_IMAGE_SHAPE)
    labels.set_shape([])
    return images, labels

I just can’t figure out why the shape of the images should be stated twice. Again, thank you for your time.

Ah, yes that could use a little more explanation.

tf.data and tf.function (also used by Keras) use TensorFlow’s two-stage process: first it builds the graph, and then it executes the graph on your tensors.

Tensor.shape is the shape that was determined at graph build time (tf.shape(tensor) gets you the runtime shape).
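
A small illustration of that difference (not from the original reply; the input_signature is only there to make the batch dimension unknown at build time):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 128, 128, 1], dtype=tf.float32)])
def show_shapes(x):
    print("build-time shape:", x.shape)  # (None, 128, 128, 1) -- printed once, while tracing
    return tf.shape(x)                   # runtime shape, concrete when the graph is executed

print(show_shapes(tf.zeros([16, 128, 128, 1])))  # tf.Tensor([ 16 128 128   1], ...)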

So what’s happening here is that tf.numpy_function(load_image, ...) says “don’t build a graph for this part, just run it in python”. So none of the code in load_image runs during graph building, and TensorFlow doesn’t know you’re setting the shape in there.

Try putting only the non-TensorFlow parts in your tf.numpy_function(load_image, ...). Move everything from img = preprocessing_func(image=img) down into another function; then you should be able to ds.map() that function without the numpy_function wrapper.
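
A rough sketch of that split (read_image and tf_preprocess are hypothetical names, and it assumes cv2.imread returns a 3-channel image): the numpy_function only reads the file, while the preprocessing and augmentation run as graph ops, so their shapes survive graph building:

def read_image(image_file):
    # Pure Python/NumPy part: runs outside the graph via tf.numpy_function.
    image_file = image_file.decode('utf-8')
    img = cv2.imread(image_file)                          # uint8, H x W x 3
    label = np.float32(os.path.basename(image_file)[0])   # first character of the file name
    return img.astype(np.float32), label

def tf_preprocess(img, label):
    # TensorFlow part: graph ops, so the shapes below are known at build time.
    img.set_shape([None, None, 3])
    img = preprocessing_func(image=img)
    img = augment(img)
    img = tf.cast(img, tf.float32)
    img.set_shape(INPUT_IMAGE_SHAPE)
    label.set_shape([])
    return img, label

train_ds = tf.data.Dataset.list_files(str(DATA_PATH / '*.tiff'))
train_ds = train_ds.map(lambda x: tf.numpy_function(read_image, [x], (tf.float32, tf.float32)))
train_ds = train_ds.map(tf_preprocess)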

But if you can use tfio’s decode_tiff, maybe you can drop the numpy_function entirely.
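
For example, a sketch assuming tensorflow_io is installed (decode_tiff yields a 4-channel RGBA uint8 tensor, hence the [..., :3] slice):

import tensorflow_io as tfio

def load_image_tf(path):
    # All TensorFlow ops, so no numpy_function is needed and shapes are tracked by the graph.
    raw = tf.io.read_file(path)
    img = tfio.experimental.image.decode_tiff(raw)   # uint8, RGBA
    img = tf.image.rgb_to_grayscale(img[..., :3])    # drop alpha, convert to single channel
    img = tf.cast(img, tf.float32)
    label = tf.strings.substr(tf.strings.split(path, "\\")[-1], pos=0, len=1)
    label = tf.strings.to_number(label, out_type=tf.float32)
    return img, label

train_ds = tf.data.Dataset.list_files(str(DATA_PATH / '*.tiff')).map(load_image_tf)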


Thank you so much for your explanation