How to set the cardinality of a dataset created using tf.data.Dataset.from_generator()?

I am currently working on a project that uses Hugging Face. I created the Hugging Face datasets and converted them to TensorFlow. For the conversion I did not use from_tensor_slices(), the method shown in their documentation, but from_generator(). I found this method a lot faster, but when training with TFTrainer() I get this error:

ValueError: The training dataset must have an asserted cardinality

I checked and found that from_generator() was the cause. To verify this, I created a very basic dataset with from_generator() and checked its cardinality:

import tensorflow as tf
dumm_ds = tf.data.Dataset.from_generator(lambda: [1] * 1000, output_signature=tf.TensorSpec(shape=[], dtype=tf.int64))
tf.data.experimental.cardinality(dumm_ds)

Output:
<tf.Tensor: shape=(), dtype=int64, numpy=-2>
where -2 means UNKNOWN_CARDINALITY.
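For comparison, here is a minimal sketch (assuming TF 2.x in eager mode) of how the different kinds of cardinality show up. from_tensor_slices() knows its length up front, from_generator() cannot know how many items a Python generator will produce, and repeat() without a count is infinite. The variable names are illustrative; tf.data.experimental.UNKNOWN_CARDINALITY (-2) and tf.data.experimental.INFINITE_CARDINALITY (-1) are the named constants behind the raw numbers.

```python
import tensorflow as tf

# Known length: the input tensor has 1000 rows, so tf.data can compute it.
sliced = tf.data.Dataset.from_tensor_slices(tf.ones([1000], dtype=tf.int64))

# Unknown length: tf.data cannot inspect a Python generator ahead of time,
# even though this one happens to be finite.
generated = tf.data.Dataset.from_generator(
    lambda: [1] * 1000,
    output_signature=tf.TensorSpec(shape=[], dtype=tf.int64))

# Infinite length: repeat() with no count never ends.
repeated = sliced.repeat()

for name, ds in [("sliced", sliced), ("generated", generated), ("repeated", repeated)]:
    print(name, int(ds.cardinality()))

# Compare against the named constants instead of the raw -2 / -1.
print(int(generated.cardinality()) == tf.data.experimental.UNKNOWN_CARDINALITY)
print(int(repeated.cardinality()) == tf.data.experimental.INFINITE_CARDINALITY)
```

The loop prints 1000, -2 and -1 respectively, so only the from_generator() dataset ends up with an unknown length.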

Is this a bug? If not, how can I set the cardinality?

Check

https://tensorflow-prod.ospodiscourse.com/t/typeerror-dataset-length-is-unknown-tensorflow/948/3?u=bhack


The link does not work anymore. Can you please point to the solution?

@neuron, please refer to the post "TypeError: dataset length is unknown tensorflow - #3 by Bhack". Thank you.

Okay, I tested it. Here is a short snippet for others to try.
A nice side effect is that it raises an error if we accidentally write a generator that yields more elements than the asserted cardinality.

import tensorflow as tf


def g():
    while True:
        yield 33


ds = tf.data.Dataset.from_generator(
    g, output_signature=tf.TensorSpec(shape=[], dtype=tf.int32))
# prints -2 (UNKNOWN_CARDINALITY): tf.data cannot know the generator's length
print("ds.cardinality()", ds.cardinality())

# take(100) alone still has unknown cardinality; assert_cardinality(100) sets it
ds1 = ds.take(100).apply(tf.data.experimental.assert_cardinality(100))
print("ds1.cardinality()", ds1.cardinality())
# this iterates all 100 elements without error
for _ in ds1:
    ...

ds2 = ds.take(150).apply(tf.data.experimental.assert_cardinality(100))
print("ds2.cardinality()", ds2.cardinality())
# this raises an error because ds2 yields more than 100 elements
for _ in ds2:
    ...
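Applied to the original question, a minimal sketch: when the generator is finite and its length is known in advance, assert_cardinality() can be applied directly, without take(). NUM_EXAMPLES and gen() below are hypothetical stand-ins for the length of the Hugging Face split and the actual conversion generator. Once the cardinality is asserted, len(ds) also works instead of raising the "dataset length is unknown" TypeError, and the dataset should report a non-negative cardinality, which is what the TFTrainer error asks for. The assertion is checked during iteration in both directions, so a count that does not match what the generator actually yields also fails.

```python
import tensorflow as tf

NUM_EXAMPLES = 1000  # hypothetical: the known number of examples the generator yields


def gen():
    # Hypothetical finite generator standing in for the Hugging Face conversion.
    for i in range(NUM_EXAMPLES):
        yield i


ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[], dtype=tf.int64))

# Before asserting: cardinality is unknown, so len() raises TypeError.
len_failed = False
try:
    len(ds)
except TypeError as e:
    len_failed = True
    print("len() failed:", e)

# Assert the length we know to be true; no take() needed for a finite generator.
ds = ds.apply(tf.data.experimental.assert_cardinality(NUM_EXAMPLES))
print("cardinality:", int(ds.cardinality()))
print("len:", len(ds))

# Asserting a wrong count (one more than the generator yields) fails at the
# end of iteration, when the generator runs out early.
wrong = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[], dtype=tf.int64)
).apply(tf.data.experimental.assert_cardinality(NUM_EXAMPLES + 1))
wrong_failed = False
try:
    for _ in wrong:
        pass
except tf.errors.OpError as e:
    # The exact subclass may vary by TF version; OpError is the common base.
    wrong_failed = True
    print("assert_cardinality failed:", type(e).__name__)
```

The trade-off is that the count must be correct: the assertion only changes the reported cardinality, and a mismatch surfaces as a runtime error during iteration rather than when the dataset is built.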