Generating a dataset of .npy files using the tf.data API

Hi, I am trying to create a custom dataset using the tf.data API. The dataset consists of 3-dimensional NumPy (.npy) files, each holding an array of shape (128, 128, 128, 3). I am using the following code to create the dataset:

import glob

import tensorflow as tf

def parse_image(image_path: str) -> dict:
    # Read the file's raw bytes and reinterpret them as float32 values
    image = tf.io.read_file(image_path)
    image = tf.io.decode_raw(image, out_type=tf.float32)
    return {'image': image}

# IMAGE_PATH is a glob pattern matching the .npy files
images = glob.glob(IMAGE_PATH)
train_ds = tf.data.Dataset.list_files(images)
train_ds = train_ds.map(parse_image, num_parallel_calls=24)
train_ds = train_ds.batch(4)

But the output shape I get is (2, 12582944), which can’t be reshaped to (128, 128, 128, 3). What am I doing wrong, and what can I do to fix it?
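The numbers are consistent with decode_raw reading the entire .npy file, header included, as if it were float32. A quick check of the arithmetic, assuming the arrays were written with np.save as float64 (the dtype is inferred from the byte count; the question does not state it):

```python
import io

import numpy as np

shape = (128, 128, 128, 3)
reported_len = 12582944              # float32 values per file reported by decode_raw

elements = int(np.prod(shape))       # 6291456 elements per volume
file_bytes = reported_len * 4        # decode_raw consumed 4 bytes per float32 value
data_bytes = elements * 8            # assumption: the arrays were saved as float64
header_bytes = file_bytes - data_bytes

print(elements, file_bytes, data_bytes, header_bytes)
# 6291456 50331776 50331648 128

# Cross-check: how many header bytes np.save writes for this shape and dtype.
buf = io.BytesIO()
np.save(buf, np.zeros(shape, dtype=np.float64))
npy_header_bytes = len(buf.getvalue()) - data_bytes
print(npy_header_bytes)
```

So each file is roughly twice as long as expected (8-byte values decoded as 4-byte ones) plus 32 extra float32 values, which are the 128 header bytes.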

Note: loading the whole dataset into NumPy first and then converting it to a tf.data Dataset is not an option, because there are around 1000 images and they can’t all fit in memory at the same time.

Maybe using tf.image.decode_png or tf.image.decode_jpeg instead of decode_raw will help.

That won’t work, because the images are stored in .npy format, not as PNG or JPEG.