U-Net not working correctly

Good morning everyone,

I’m trying a U-Net-based AI model to segment rat brains from MRI images. For that, I’m training the model with 574 MRI images (192x192x1, .tiff) and masks of the same dimensions (.nii) that I segmented manually using ImageJ. The code is based on the January 2020 paper “Automatic Skull Stripping of Rat and Mouse Brain Data Using U-Net”. More specifically:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate, BatchNormalization

def model4(input_size=(192, 192, 1)):
    inputs = Input(input_size)

    # Encoder: two 3x3 conv + batch norm layers per level, then 2x2 max pooling
    conv1 = Conv2D(32, 3, activation='relu', padding='same')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Conv2D(32, 3, activation='relu', padding='same')(conv1)
    conv1 = BatchNormalization()(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, 3, activation='relu', padding='same')(pool1)
    conv2 = BatchNormalization()(conv2)
    conv2 = Conv2D(64, 3, activation='relu', padding='same')(conv2)
    conv2 = BatchNormalization()(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(96, 3, activation='relu', padding='same')(pool2)
    conv3 = BatchNormalization()(conv3)
    conv3 = Conv2D(96, 3, activation='relu', padding='same')(conv3)
    conv3 = BatchNormalization()(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(128, 3, activation='relu', padding='same')(pool3)
    conv4 = BatchNormalization()(conv4)
    conv4 = Conv2D(128, 3, activation='relu', padding='same')(conv4)
    conv4 = BatchNormalization()(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(256, 3, activation='relu', padding='same')(pool4)
    conv5 = BatchNormalization()(conv5)
    conv5 = Conv2D(256, 3, activation='relu', padding='same')(conv5)
    conv5 = BatchNormalization()(conv5)
    pool5 = MaxPooling2D(pool_size=(2, 2))(conv5)

    # Bottleneck at 6x6 for a 192x192 input
    conv6 = Conv2D(512, 3, activation='relu', padding='same')(pool5)
    conv6 = BatchNormalization()(conv6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same')(conv6)
    conv6 = BatchNormalization()(conv6)
    pool6 = MaxPooling2D(pool_size=(2, 2))(conv6)  # not used by the decoder below

    # Decoder: upsample, 3x3 conv, concatenate with the encoder level of the same size
    up1 = UpSampling2D(size=(2, 2))(conv6)
    up1 = Conv2D(256, 3, activation='relu', padding='same')(up1)
    up1 = BatchNormalization()(up1)
    merge1 = concatenate([conv5, up1], axis=3)
    conv7 = Conv2D(256, 3, activation='relu', padding='same')(merge1)
    conv7 = BatchNormalization()(conv7)
    conv7 = Conv2D(256, 3, activation='relu', padding='same')(conv7)
    conv7 = BatchNormalization()(conv7)

    up2 = UpSampling2D(size=(2, 2))(conv7)
    up2 = Conv2D(128, 3, activation='relu', padding='same')(up2)
    up2 = BatchNormalization()(up2)
    merge2 = concatenate([conv4, up2], axis=3)
    conv8 = Conv2D(128, 3, activation='relu', padding='same')(merge2)
    conv8 = BatchNormalization()(conv8)
    conv8 = Conv2D(128, 3, activation='relu', padding='same')(conv8)
    conv8 = BatchNormalization()(conv8)

    up3 = UpSampling2D(size=(2, 2))(conv8)
    up3 = Conv2D(96, 3, activation='relu', padding='same')(up3)
    up3 = BatchNormalization()(up3)
    merge3 = concatenate([conv3, up3], axis=3)
    conv9 = Conv2D(96, 3, activation='relu', padding='same')(merge3)
    conv9 = BatchNormalization()(conv9)
    conv9 = Conv2D(96, 3, activation='relu', padding='same')(conv9)
    conv9 = BatchNormalization()(conv9)

    up4 = UpSampling2D(size=(2, 2))(conv9)
    up4 = Conv2D(64, 3, activation='relu', padding='same')(up4)
    up4 = BatchNormalization()(up4)
    merge4 = concatenate([conv2, up4], axis=3)
    conv10 = Conv2D(64, 3, activation='relu', padding='same')(merge4)
    conv10 = BatchNormalization()(conv10)
    conv10 = Conv2D(64, 3, activation='relu', padding='same')(conv10)
    conv10 = BatchNormalization()(conv10)

    up5 = UpSampling2D(size=(2, 2))(conv10)
    up5 = Conv2D(32, 3, activation='relu', padding='same')(up5)
    up5 = BatchNormalization()(up5)
    merge5 = concatenate([conv1, up5], axis=3)
    conv11 = Conv2D(32, 3, activation='relu', padding='same')(merge5)
    conv11 = BatchNormalization()(conv11)
    conv11 = Conv2D(32, 3, activation='relu', padding='same')(conv11)
    conv11 = BatchNormalization()(conv11)

    # 1x1 conv + sigmoid: one brain probability per pixel
    outputs = Conv2D(1, 1, activation='sigmoid')(conv11)

    model = Model(inputs=inputs, outputs=outputs)
    return model

# Use a new name so the model4 function is not overwritten by its own result
model = model4()
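The skip connections only line up if every pooling step halves the size exactly. `model4` pools five times before `conv6` (`pool6` is computed but never used by the decoder), so the input side must be divisible by 2^5 = 32, which 192 is. The helper below is a hypothetical pure-Python sketch of that arithmetic, not something from the paper's code.

```python
def encoder_resolutions(size, n_pools):
    """Return the spatial size at each encoder level of a U-Net that
    halves the resolution n_pools times with 2x2 max pooling."""
    sizes = [size]
    for _ in range(n_pools):
        if sizes[-1] % 2 != 0:
            # MaxPooling2D would floor this size, so the matching UpSampling2D
            # output would be one pixel smaller than the skip connection and
            # concatenate() would fail with a shape mismatch.
            raise ValueError(f"size {sizes[-1]} is not divisible by 2")
        sizes.append(sizes[-1] // 2)
    return sizes


# model4 pools five times before conv6 (pool6 does not feed the decoder)
print(encoder_resolutions(192, 5))  # [192, 96, 48, 24, 12, 6]
```

A 200x200 input, for example, would fail at the fourth pooling step (25 is odd), which is why inputs are usually cropped or padded to a multiple of 32 for this depth.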

Furthermore, I normalized the MRI images using intensity and spatial normalization. The loss functions I’ve tried are Dice loss, binary cross-entropy (BCE) and focal Tversky loss with different alpha/beta values.
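Since the loss code isn't shown, here is a minimal NumPy sketch of the soft Dice loss and the focal Tversky loss, assuming a single foreground class, masks stored as 0/1 values and predictions in [0, 1]. It is meant for checking numbers by hand, not for training; a Keras loss would apply the same formulas with `tf` ops. Here `alpha` weights false negatives and `beta` false positives, but implementations disagree on that convention, and `gamma = 0.75` and `smooth = 1.0` are assumed defaults.

```python
import numpy as np


def dice_loss(y_true, y_pred, smooth=1.0):
    """Soft Dice loss: 1 - (2*TP + smooth) / (sum(y_true) + sum(y_pred) + smooth).

    Assumes y_true holds 0/1 values and y_pred holds probabilities in [0, 1].
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)


def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=0.75, smooth=1.0):
    """Focal Tversky loss (1 - TI) ** gamma, with the Tversky index
    TI = (TP + smooth) / (TP + alpha*FN + beta*FP + smooth).

    alpha weights false negatives and beta weights false positives here;
    some implementations swap them, so check the convention you use.
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    tp = np.sum(y_true * y_pred)            # soft true positives
    fn = np.sum(y_true * (1.0 - y_pred))    # soft false negatives
    fp = np.sum((1.0 - y_true) * y_pred)    # soft false positives
    tversky = (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)
    return (1.0 - tversky) ** gamma


# Usage: a rectangular "brain" covering 70 x 90 of the 192 x 192 pixels
mask = np.zeros((192, 192))
mask[60:130, 50:140] = 1.0
print(dice_loss(mask, mask))                  # 0.0 for a perfect prediction
print(dice_loss(mask, np.ones_like(mask)))    # ≈ 0.708 for an all-white prediction
```

With 0/1 targets an all-white prediction is clearly penalized, so if training still collapses to white it may be worth checking the data fed to the loss. With alpha = beta = 0.5 and gamma = 1, the Tversky loss reduces to the Dice loss (with the smoothing term doubled in the Dice version).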

The result is an entirely white image, with no mask. I don’t know what I’m doing wrong. I’m very new to all this, so I’m probably missing something important.
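One way to narrow down a white output is to look at the numbers rather than the picture. The range of the raw sigmoid map and the fraction of pixels above 0.5 show whether the network really predicts brain everywhere. The distinct values in a mask show whether it is stored as 0/1; masks exported from ImageJ may hold 0/255 (ImageJ binary images typically do, but this is an assumption to verify on your own files), and BCE or Dice with 255-valued targets no longer behaves as intended. The function names below are hypothetical helpers, sketched with NumPy.

```python
import numpy as np


def summarize_prediction(prob_map, threshold=0.5):
    """Summarize a sigmoid output map so an 'all white' result can be checked numerically."""
    p = np.asarray(prob_map, dtype=np.float64)
    return {
        "min": float(p.min()),
        "max": float(p.max()),
        "mean": float(p.mean()),
        # fraction of pixels labeled as brain at this threshold
        "foreground_fraction": float(np.mean(p > threshold)),
    }


def mask_values(mask):
    """Return the distinct values stored in a mask; a 0/1 mask gives only 0 and 1."""
    return np.unique(np.asarray(mask))


def to_uint8_mask(prob_map, threshold=0.5):
    """Threshold probabilities and scale to 0/255 uint8 for saving as an image."""
    return np.where(np.asarray(prob_map) > threshold, 255, 0).astype(np.uint8)


# Usage with stand-in arrays (replace with a real prediction and a real mask)
pred = np.full((192, 192), 0.97)
print(summarize_prediction(pred)["foreground_fraction"])  # 1.0
print(mask_values(np.array([[0, 255], [255, 0]])))        # 0 and 255: not a 0/1 mask
```

If the masks turn out to be 0/255, dividing by 255 (or using `mask > 0`) before training is the usual fix.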

By the way, the batch size and the learning rate are 16 and 0.001, respectively.

Thanks in advance and sorry for any grammatical mistakes.

Hi @Sara_Ortega_Espina, and welcome to the TensorFlow forum.
I think it would be useful if you provided the full code, including correctly shaped tensors filled with random data (instead of your actual data), as well as the code you use to display the images output by your model. Your model could be fine, and the issue might be in how the images are displayed.
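As a minimal sketch of such random data, assuming the same 192x192x1 float32 layout as the real images, the hypothetical helper below builds random images plus binary masks that each contain one random disk of 1s, so they loosely resemble brain masks.

```python
import numpy as np


def make_random_dataset(n=8, size=192, seed=0):
    """Hypothetical helper: random images shaped (n, size, size, 1) and
    0/1 masks of the same shape, each holding one random disk."""
    rng = np.random.default_rng(seed)
    images = rng.random((n, size, size, 1), dtype=np.float32)
    yy, xx = np.mgrid[0:size, 0:size]           # pixel row/column coordinates
    masks = np.zeros((n, size, size, 1), dtype=np.float32)
    for i in range(n):
        cy, cx = rng.integers(size // 4, 3 * size // 4, size=2)  # disk center
        r = rng.integers(size // 8, size // 4)                   # disk radius
        disk = (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2
        masks[i, ..., 0] = disk.astype(np.float32)
    return images, masks


images, masks = make_random_dataset()
print(images.shape, masks.shape)  # (8, 192, 192, 1) (8, 192, 192, 1)
```

These arrays can be passed to `model.fit(images, masks, batch_size=16)` in place of the real data, which makes the whole script shareable and helps separate model problems from data problems.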
Thank you.