TensorFlow Dataset and NetCDF (tf.py_function): len() reports unknown length

Hi, everyone.

I am running into an issue: I can't query the length of a dataset. I need the length after batching, because I use it as decay_steps for my ExponentialDecay learning-rate schedule.

In a previous workflow, I did something like this…

import tensorflow as tf

def create_full_dataset(X_hi, X_low, Y_hi, X_hi_name='input_hi', X_low_name='input_low', activation='sigmoid', training=False):
    activation = activation.lower()
    ACTIVATIONS = ['sigmoid', 'tanh']
    if activation not in ACTIVATIONS:
        raise ValueError("Activation function unknown. Options: sigmoid, tanh.")

    if training:
        # In eager mode these are drawn once, when create_full_dataset() is called,
        # so every element gets the same flip/rotation (shared by hi, low and target).
        rnFlip = tf.random.uniform(shape=(), maxval=1.0, dtype=tf.float64)
        rnRot = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        randoms = [rnFlip, rnRot]
    else:
        randoms = [None, None]

    # decode_file() is my own PNG loader (not shown); map() keeps one element per file.
    X_hi = X_hi.map(lambda x: decode_file(x, randoms=randoms, activation=activation), num_parallel_calls=tf.data.AUTOTUNE)
    X_low = X_low.map(lambda x: decode_file(x, randoms=randoms, activation=activation), num_parallel_calls=tf.data.AUTOTUNE)
    Y_hi = Y_hi.map(lambda x: decode_file(x, randoms=randoms, activation=activation), num_parallel_calls=tf.data.AUTOTUNE)

    ds_X = tf.data.Dataset.zip((X_hi, X_low)).map(lambda X_hi, X_low: {X_hi_name: X_hi, X_low_name: X_low})
    return tf.data.Dataset.zip((ds_X, Y_hi))

...

    ds_train_hi = tf.data.Dataset.list_files(str(trainingHiRes / config['training']['glob']['training_hires']),
                                             shuffle=False)  # Definitely do NOT shuffle these datasets at read-in time
    ds_train_lo = tf.data.Dataset.list_files(str(trainingLoRes / config['training']['glob']['training_lowres']),
                                             shuffle=False)
    ds_target_tr = tf.data.Dataset.list_files(str(trainingTarget / config['training']['glob']['truth_train']),
                                              shuffle=False)

    ds_train = create_full_dataset(ds_train_hi, ds_train_lo, ds_target_tr,
                                   activation = config['training']['model']['activation_final'],
                                   training=config['training']['dataset']['augment'])
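For reference, a minimal sketch of why len() works for this kind of pipeline: Dataset.map and Dataset.zip keep a known cardinality (zip takes the shortest input), and batch/prefetch derive their length from it. It assumes from_tensor_slices over made-up filenames in place of list_files, so it runs without files on disk, and fake_decode is a hypothetical stand-in for decode_file.

```python
import tensorflow as tf

# Hypothetical stand-ins for the PNG file lists returned by list_files().
hi_files = tf.data.Dataset.from_tensor_slices([f"hi_{i}.png" for i in range(10)])
lo_files = tf.data.Dataset.from_tensor_slices([f"lo_{i}.png" for i in range(10)])
target_files = tf.data.Dataset.from_tensor_slices([f"y_{i}.png" for i in range(10)])

def fake_decode(path):
    # Placeholder for decode_file(): any one-in, one-out map keeps the length.
    return tf.strings.length(path)

X_hi = hi_files.map(fake_decode, num_parallel_calls=tf.data.AUTOTUNE)
X_low = lo_files.map(fake_decode, num_parallel_calls=tf.data.AUTOTUNE)
Y_hi = target_files.map(fake_decode, num_parallel_calls=tf.data.AUTOTUNE)

ds_X = tf.data.Dataset.zip((X_hi, X_low)).map(lambda a, b: {"input_hi": a, "input_low": b})
ds_train = tf.data.Dataset.zip((ds_X, Y_hi))

batch_size = 4
ds_batched = ds_train.batch(batch_size).prefetch(tf.data.AUTOTUNE)

print(len(ds_train))    # 10 elements
print(len(ds_batched))  # 3 batches: ceil(10 / 4), the last one partial
```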

All of the input files were .png files. Later on, I was able to do this:

    # K is assumed to be tensorflow.keras
    if final_lr is not None:
        # Per-epoch decay factor so the rate reaches final_lr after `epochs` epochs.
        lr_decay_factor = (final_lr / initial_lr) ** (1.0 / epochs)
        if steps_per_epoch is None:
            steps_per_epoch = len(ds_train)  # ds_train is already batched, so this counts batches
        lr_schedule = K.optimizers.schedules.ExponentialDecay(
                initial_learning_rate = initial_lr,
                decay_steps = steps_per_epoch,
                decay_rate = lr_decay_factor)
    else:
        lr_schedule = initial_lr
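A quick sanity check of the decay arithmetic, assuming decay_steps equals the number of batches per epoch: with the default staircase=False, ExponentialDecay computes initial_lr * decay_rate ** (step / decay_steps), so decay_rate = (final_lr / initial_lr) ** (1 / epochs) lands on final_lr after epochs * steps_per_epoch steps. The numbers below are made up.

```python
import tensorflow as tf

# Made-up values for illustration only.
initial_lr = 1e-3
final_lr = 1e-5
epochs = 20
steps_per_epoch = 50  # would normally come from len() of the batched dataset

lr_decay_factor = (final_lr / initial_lr) ** (1.0 / epochs)
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=steps_per_epoch,
    decay_rate=lr_decay_factor)

# The schedule is a callable of the optimizer step count.
lr_start = float(lr_schedule(0))
lr_after_one_epoch = float(lr_schedule(steps_per_epoch))
lr_end = float(lr_schedule(epochs * steps_per_epoch))
print(lr_start, lr_after_one_epoch, lr_end)  # approximately 1e-3, 7.9e-4 and 1e-5
```

This is also why a wrong steps_per_epoch matters: if decay_steps is off, the final rate overshoots or undershoots final_lr by the same ratio in the exponent.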

That worked great! However, my new dataset is in NetCDF format, so I had to roll my own unpack code. I now do something like this:

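def create_dataset(path, glob, cfg, parallel = None):
    # unpack_netcdf_xarray, extract_X and extract_Y are my own helpers (not shown).
    # map() over tf.py_function still yields exactly one element per file.
    output = tf.data.Dataset.list_files(str(path / glob), shuffle = True, seed = 42).map(
            lambda x: tf.py_function(unpack_netcdf_xarray,
                                     [x, cfg['training']['dataset']['prior'], cfg['training']['dataset']['model_vars']],
                                     Tout = [tf.float32, tf.float32, tf.float32, tf.float32]),
            num_parallel_calls = tf.data.AUTOTUNE, deterministic=False)
    # flat_map() may emit any number of elements per file, so its cardinality is
    # unknown to tf.data, and zip() inherits that. Each zip() branch also iterates
    # `output` separately, so every file is unpacked twice, and with
    # deterministic=False the two branches may not see the files in the same order.
    return tf.data.Dataset.zip( (output.flat_map(lambda p, s, sc, y: extract_X(p, s, sc, y)),
                                 output.flat_map(lambda p, s, sc, y: extract_Y(p, s, sc, y))) )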
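A minimal reproduction, assuming a made-up in-memory file list instead of real NetCDF files (fake_unpack and load are hypothetical stand-ins): the map over tf.py_function keeps a known length, but tf.data cannot know in advance how many elements the flat_map function will emit, so cardinality becomes unknown and len() raises TypeError no matter whether batch or prefetch follow.

```python
import numpy as np
import tensorflow as tf

# Hypothetical file names; nothing is read from disk.
files = tf.data.Dataset.from_tensor_slices(["a.nc", "b.nc", "c.nc"])

def fake_unpack(path):
    # Stand-in for unpack_netcdf_xarray: 5 samples with 2 features per "file".
    return np.ones((5, 2), dtype=np.float32)

def load(path):
    x = tf.py_function(fake_unpack, [path], Tout=tf.float32)
    x.set_shape([None, 2])  # py_function output carries no static shape
    return x

unpacked = files.map(load, num_parallel_calls=tf.data.AUTOTUNE)
print(unpacked.cardinality().numpy())  # 3: map keeps one element per file

samples = unpacked.flat_map(tf.data.Dataset.from_tensor_slices)
print(samples.cardinality().numpy() == tf.data.UNKNOWN_CARDINALITY)  # True

len_error = None
try:
    len(samples.batch(4).prefetch(tf.data.AUTOTUNE))
except TypeError as err:
    len_error = err  # the same "length is unknown" TypeError as in the post
print(repr(len_error))
```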

...
    ds = create_dataset(path = trainDataPath,
                        glob = globString,
                        cfg = config,
                        parallel = None)
    ds = ds.batch(batchSize).prefetch(tf.data.AUTOTUNE)
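If every file is known to hold the same number of samples, one option is to declare the length with tf.data.experimental.assert_cardinality after flat_map; batch then computes its length from that, and iteration raises an error if the declared count is wrong. The sketch below also unpacks each file once and emits (X, Y) pairs from a single flat_map instead of zipping two branches. SAMPLES_PER_FILE, fake_unpack, load, the file names and the feature shapes are assumptions for illustration, not the real extract_X/extract_Y.

```python
import numpy as np
import tensorflow as tf

SAMPLES_PER_FILE = 5  # assumption: every file holds the same number of samples
file_names = ["a.nc", "b.nc", "c.nc"]  # hypothetical names

def fake_unpack(path):
    # Stand-in for unpack_netcdf_xarray + extract_X/extract_Y: one read per file,
    # returning inputs and targets together so they stay paired.
    x = np.random.rand(SAMPLES_PER_FILE, 8).astype(np.float32)
    y = np.random.rand(SAMPLES_PER_FILE, 1).astype(np.float32)
    return x, y

def load(path):
    x, y = tf.py_function(fake_unpack, [path], Tout=[tf.float32, tf.float32])
    x.set_shape([None, 8])  # py_function drops static shapes
    y.set_shape([None, 1])
    return x, y

files = tf.data.Dataset.from_tensor_slices(file_names).shuffle(len(file_names), seed=42)
ds = files.map(load, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
ds = ds.flat_map(lambda x, y: tf.data.Dataset.from_tensor_slices((x, y)))

# Declare the known sample count so downstream transformations have a length.
ds = ds.apply(tf.data.experimental.assert_cardinality(len(file_names) * SAMPLES_PER_FILE))

batch_size = 4
ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
steps_per_epoch = len(ds)
print(steps_per_epoch)  # 4 batches: ceil(15 / 4)
```

Because X and Y come out of the same flat_map element, deterministic=False can reorder files without ever mismatching inputs and targets.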

Using take(), I can see that the data is populated correctly, it's batched correctly, and prefetch does… whatever it does under the hood. But when I call len(ds) as before, I get the dreaded TypeError("The dataset length is unknown."). Removing prefetch makes no difference.
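If the number of samples per file varies, another option is to count once by iterating, for example with Dataset.reduce, and pass the result as steps_per_epoch (the schedule code already accepts an explicit value). The cost is one extra full read of every file. A sketch, assuming a synthetic flat_mapped dataset with 4, 6 and 5 samples per "file" in place of the NetCDF pipeline:

```python
import tensorflow as tf

# Stand-in for the flat_mapped NetCDF dataset: 3 "files" with 4, 6 and 5 samples,
# so tf.data cannot know the length up front.
ds = tf.data.Dataset.from_tensor_slices([4, 6, 5]).flat_map(
    lambda n: tf.data.Dataset.range(tf.cast(n, tf.int64)))

batch_size = 4
ds_batched = ds.batch(batch_size)
print(ds_batched.cardinality().numpy())  # -2, i.e. tf.data.UNKNOWN_CARDINALITY

# One full pass to count batches; this costs an extra epoch of I/O.
steps_per_epoch = int(ds_batched.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1))
print(steps_per_epoch)  # 4 batches: 15 samples, 4 per batch, last one partial
```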

Is this a known issue? A bug? Did I do something wrong?