Problem loading image patches using pix2pix as a model

Hello there. I would like to load image patches into my TensorFlow model. I followed the pix2pix loading example and adapted it to my model, but I cannot load the patches. Thanks for the assistance. Here is my code:

def load2(image_file):
    image_ts =
    image_ts = tf.image.decode_png(image_ts)

    # Calculate the center of the image
    w_center = 358

    # Split the image into two parts
    input_image = image_ts[:, w_center:, :]
    real_image = image_ts[:, :w_center, :]

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    input_image, real_image = resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)

    input_patches = []
    gt_patches = []

    # Extract patches using tf.image.extract_patches.
    # padding is a required argument; "VALID" is assumed here and keeps
    # only patches that fit entirely inside the image.
    input_patches = tf.image.extract_patches(
        input_image[tf.newaxis, ...],
        sizes=[1, patch_size[0], patch_size[1], 1],
        strides=[1, stride[0], stride[1], 1],
        rates=[1, 1, 1, 1],
        padding="VALID")

    gt_patches = tf.image.extract_patches(
        real_image[tf.newaxis, ...],
        sizes=[1, patch_size[0], patch_size[1], 1],
        strides=[1, stride[0], stride[1], 1],
        rates=[1, 1, 1, 1],
        padding="VALID")

    # Reshape patches to get a list of individual patches
    input_patches = tf.reshape(input_patches, [-1, patch_size[0], patch_size[1], 3])
    gt_patches = tf.reshape(gt_patches, [-1, patch_size[0], patch_size[1], 3])

    return input_patches, gt_patches, image_file


def load_image_train(image_file):
    input_image, real_image, image_file = load2(image_file)

    input_image, real_image = random_jitter(input_image, real_image)

    input_image, real_image = normalize_images(input_image, real_image)

    return input_image, real_image, image_file

input_directory_train = "/home/rafael/Área de Trabalho/Dataset-PNG"
input_directory_test = "/home/rafael/Área de Trabalho/Dataset-PNG/Test"
input_directory_val = "/home/rafael/Área de Trabalho/Dataset-PNG/Val"

# Use glob to get a list of image file paths in the input directory

image_files = glob.glob(input_directory_train + "/*.png") # Change the file extension as needed
image_files_test = glob.glob(input_directory_test + "/*.png") # Change the file extension as needed
image_files_val = glob.glob(input_directory_val + "/*.png") # Change the file extension as needed

# Create a dataset from the list of image files

dataset =
dataset =,

# Define batch size and other data pipeline operations (e.g., shuffle, repeat, etc.)
batch_size = 32

dataset = dataset.shuffle(buffer_size=1000) # Shuffle the data
dataset = dataset.batch(BATCH_SIZE) # Batch the data
dataset = dataset.prefetch( # Prefetch for better performance

# Skip the first N elements and take the next M elements

N = 0
M = 200
dataset = dataset.skip(N).take(M)

Hi @Rafael_Scatena & welcome to the TensorFlow forum.
Can you please be more specific about the issue you are facing? Can you also provide details about your dataset and any other relevant information? What error message do you get?
Thank you.

I get a message saying that the input shape is (1,0,16,16,1), but it should be (1,16,16,1).

Hi @Rafael_Scatena, judging by the error above, the model is being passed an input with the wrong shape. This probably happens while the data is being processed.
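
One possible reading of the reported shape, offered as a guess rather than a diagnosis: if each dataset element is a stack of patches with shape (num_patches, 16, 16, channels), batching adds a leading axis, and a 0 in the second position would mean that no patch fit inside the image. The sketch below works through that arithmetic in plain Python for "VALID" padding with rates of 1. The function names and the example image sizes are hypothetical and do not come from the posted code.

```python
def patches_along_axis(size, patch, stride):
    """Number of full patches along one axis with "VALID" padding and rate 1."""
    if size < patch:
        return 0  # the patch does not fit even once
    return (size - patch) // stride + 1


def batched_patch_shape(height, width, channels, patch, stride, batch_size):
    """Shape after extracting patches from one image and then batching."""
    rows = patches_along_axis(height, patch[0], stride[0])
    cols = patches_along_axis(width, patch[1], stride[1])
    return (batch_size, rows * cols, patch[0], patch[1], channels)


# A 256x256 grayscale image tiled into non-overlapping 16x16 patches.
print(batched_patch_shape(256, 256, 1, (16, 16), (16, 16), batch_size=1))
# prints (1, 256, 16, 16, 1)

# An image that is only 8 pixels high: no 16x16 patch fits.
print(batched_patch_shape(8, 256, 1, (16, 16), (16, 16), batch_size=1))
# prints (1, 0, 16, 16, 1)
```

Printing the image shape right before the patch extraction step would show whether the image is smaller than the patch at that point.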

Could you please share standalone code as a Colab gist so the issue can be reproduced? If possible, please also share some sample data; that would make it easier to reproduce the issue and find the cause of the error. Thank you.
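
If the model expects 4-D input of shape (batch, 16, 16, channels) while each dataset element is a whole stack of patches, one option is to flatten the stacks with `Dataset.unbatch()` before batching. The following is a minimal sketch under that assumption, not a confirmed fix for the posted pipeline. The names `to_patches`, `PATCH`, and `STRIDE` and the dummy all-zero images are illustrative only.

```python
import tensorflow as tf

PATCH = (16, 16)   # assumed patch height and width
STRIDE = (16, 16)  # assumed stride (non-overlapping patches)


def to_patches(image):
    """Turn one [H, W, C] image into a stack of [num_patches, ph, pw, C] patches."""
    channels = image.shape[-1]
    patches = tf.image.extract_patches(
        images=image[tf.newaxis, ...],
        sizes=[1, PATCH[0], PATCH[1], 1],
        strides=[1, STRIDE[0], STRIDE[1], 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    return tf.reshape(patches, [-1, PATCH[0], PATCH[1], channels])


# Four dummy 64x64 grayscale images stand in for the real data.
images = tf.zeros([4, 64, 64, 1])

dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .map(to_patches)   # each element: [16, 16, 16, 1] (16 patches per image)
    .unbatch()         # each element: one [16, 16, 1] patch
    .batch(8)          # each element: [8, 16, 16, 1]
)

first_batch = next(iter(dataset))
print(first_batch.shape)
```

With this ordering, every batch holds individual patches rather than per-image stacks, so no fifth axis appears.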

It looks like you’re trying to load and process image patches for a TensorFlow model. I see a few issues in your code. Let’s address them:

  1. Imports: Make sure you have the necessary imports at the beginning of your code, such as importing TensorFlow and any other required libraries.

  2. Function Definitions: You provided code snippets, but it’s not clear if you have defined the functions like resize, random_jitter, and normalize_images. Ensure that these functions are correctly defined.

  3. String Quotes: Python string literals must use regular quotes ("). Smart quotes (e.g., “ and ”) are a syntax error, so make sure none appear in your source file; forum formatting can also introduce them.

  4. File Patterns: Each glob call should append the pattern "/*.png" to the directory. A pattern such as "/.png" has no wildcard and only matches a file literally named ".png".

  5. load_image_train Function: This function seems to return multiple values, but your code doesn’t capture those returned values correctly. Make sure to unpack the values returned by the load2 function correctly.

  6. dataset Configuration: Ensure that you have defined variables like IMG_HEIGHT, IMG_WIDTH, patch_size, and stride with appropriate values. Also, the variable BATCH_SIZE should be defined.

  7. File Paths: Verify that the file paths in input_directory_train, input_directory_test, and input_directory_val are correct.

  8. File Loading: The code uses glob to get a list of image file paths, but you should ensure that the file paths are correctly obtained, and the file extensions match the actual extensions in your dataset directory.

  9. Data Pipeline Operations: You have defined data pipeline operations like shuffle, batch, and prefetch correctly. However, make sure that N and M are set to appropriate values.
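
The point about glob patterns can be checked without TensorFlow. The sketch below creates a throwaway directory with the standard library and compares a pattern that contains a wildcard with one that does not. The file names are made up for the illustration.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as folder:
    # Create two empty PNG files and one unrelated file.
    for name in ("a.png", "b.png", "notes.txt"):
        open(os.path.join(folder, name), "w").close()

    # "*" matches any file name, so both PNG files are found.
    with_wildcard = sorted(
        os.path.basename(path) for path in glob.glob(folder + "/*.png")
    )
    # Without "*", the pattern only matches a file literally named ".png".
    without_wildcard = glob.glob(folder + "/.png")

print(with_wildcard)     # ['a.png', 'b.png']
print(without_wildcard)  # []
```

An empty list from glob leads to an empty dataset, so printing `len(image_files)` before building the pipeline is a cheap sanity check.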

Once you address these issues and provide additional context or error messages if you encounter any problems, I can help you further with specific questions or issues you’re facing.

gist code: MTANN3.ipynb · GitHub

The dataset with .png is this one: Dataset-PNG - Google Drive