Cannot iterate over tensorflow dataset generator

Here is my code for reading in .npy files and their labels using generators.

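import numpy as np
import tensorflow as tf

def data_generator(file_paths):
  # Shuffle the paths in place so the files are read in random order.
  np.random.shuffle(file_paths)
  classes = tf.constant(["class 1","class 2","class 3"])

  for fn in file_paths:
    # Load the array and reshape it to (150,150,1).
    f = np.reshape(np.load(fn),(150,150,1))
    for j in range(len(classes)):
      # The label is the index of the class name that appears in the file path.
      if classes[j].numpy() in fn:
        yield f,np.array(j)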

And this is how I created my dataset:

train_ds = tf.data.Dataset.from_generator(
    data_generator,args = (train_files),
    output_signature=(
        tf.TensorSpec(shape=(150,150,1),dtype=tf.float32),
        tf.TensorSpec(shape=(),dtype=tf.int64)
    )
)

However, any time I try to plot or view the generated dataset using next(iter(train_ds.take(5))),
I get the following error:

InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} TypeError: data_generator() takes 1 positional argument but 30000 were given
Traceback (most recent call last):

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 855, in get_iterator
    return self._iterators[iterator_id]

KeyError: 1


During handling of the above exception, another exception occurred:


Traceback (most recent call last):

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1039, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 857, in get_iterator
    iterator = iter(self._generator(*self._args.pop(iterator_id)))

TypeError: data_generator() takes 1 positional argument but 30000 were given


	 [[{{node PyFunc}}]] [Op:IteratorGetNext]

train_files is a list of paths. Any help will be greatly appreciated.

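I figured it out. The args parameter of the from_generator method must be a tuple, and (train_files) without a trailing comma is not a tuple, it is just the list itself. As a result, each of the 30000 paths was passed to data_generator as a separate positional argument. Writing args=(train_files,) fixes it. I also rewrote the generator function to make it cleaner, and now everything works fine.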
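
To see where the "takes 1 positional argument but 30000 were given" message comes from, here is a minimal plain-Python sketch that does not need TensorFlow. The traceback shows the generator being called as self._generator(*self._args.pop(iterator_id)), so whatever is passed as args gets star-unpacked. The helper call_generator below is a hypothetical, simplified stand-in for that internal call (an assumption, not TensorFlow's actual code), and the file names are made up.

```python
def data_generator(file_paths):
    """Stand-in for the generator in the question: it accepts ONE argument."""
    for path in file_paths:
        yield path


def call_generator(generator, args):
    # Hypothetical simplification of the call seen in the traceback:
    #   iter(self._generator(*self._args.pop(iterator_id)))
    # The important part is the star-unpacking of args.
    return list(generator(*args))


# Made-up file names, only for illustration.
train_files = ["class 1_a.npy", "class 2_b.npy", "class 3_c.npy"]

not_a_tuple = (train_files)   # parentheses alone: still the list itself
one_tuple = (train_files,)    # trailing comma: a 1-element tuple

print(type(not_a_tuple).__name__, type(one_tuple).__name__)

try:
    # Unpacking the list passes every path as its own positional argument.
    call_generator(data_generator, not_a_tuple)
except TypeError as err:
    error_message = str(err)
    print(error_message)

# Unpacking the 1-tuple passes the whole list as the single argument.
result = call_generator(data_generator, one_tuple)
print(result)
```

With three paths the failing call reports 3 positional arguments. With the 30000 paths in train_files it reports 30000, which matches the error above.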
