Image upsampling using an interpolation function

TensorFlow Graphics (tfg) now provides a function for trilinear interpolation. However, how can I use it to upsample a 3D image? For example, I have an image x of shape (W, H, D) and I would like to upsample it to (W*2, H*2, D*2).

Hi @vRoca, I don't know how to upsample 3D data using tfg, but in core TensorFlow you can do it with tf.keras.layers.UpSampling3D. For example:

import tensorflow as tf

input_shape = (2, 1, 2, 1, 3)
x = tf.constant(1, shape=input_shape)
y = tf.keras.layers.UpSampling3D(size=2)(x)
# output: y.shape == (2, 2, 4, 2, 3)

Thank you.
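For context, the Keras documentation describes UpSampling3D as repeating the 1st, 2nd and 3rd spatial dimensions of the data by `size`. That is nearest-neighbour upsampling: voxels are copied, not interpolated. Below is a minimal NumPy sketch of the same operation, assuming the default channels_last layout (batch, dim1, dim2, dim3, channels); the helper name `upsample3d_nearest` is hypothetical.

```python
import numpy as np

def upsample3d_nearest(x, size=2):
    """Mimic UpSampling3D (channels_last): repeat each voxel `size` times per spatial axis."""
    # x has shape (batch, dim1, dim2, dim3, channels); axes 1-3 are spatial
    for axis in (1, 2, 3):
        x = np.repeat(x, size, axis=axis)
    return x

x = np.arange(2 * 1 * 2 * 1 * 3).reshape(2, 1, 2, 1, 3)
y = upsample3d_nearest(x)
print(y.shape)  # (2, 2, 4, 2, 3)
```

Every output voxel is an exact copy of one input voxel, which is why this layer does not give the smooth result that trilinear interpolation would.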

Hi @Kiran_Sai_Ramineni, yes, I know this class, but I am looking for upsampling with trilinear interpolation.
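For reference, trilinear interpolation of a voxel grid v at a fractional point (x, y, z) is a weighted average of the 8 surrounding voxels. Its weights factor into one 1-D linear weight per axis. The last line is the half-pixel-centre mapping from an output index o to an input coordinate when upsampling by a factor s, which (as far as I know) is the default convention of tf.image.resize. Indices are clamped to the grid at the borders. The symbols v, w, o and s are notation introduced here, not taken from any library.

```latex
\begin{aligned}
x_0 &= \lfloor x \rfloor, \qquad x_d = x - x_0 \quad (\text{likewise for } y, z) \\
f(x, y, z) &= \sum_{a, b, c \in \{0, 1\}} v[x_0 + a,\, y_0 + b,\, z_0 + c]\; w_a(x_d)\, w_b(y_d)\, w_c(z_d), \\
w_0(t) &= 1 - t, \qquad w_1(t) = t, \\
x &= \frac{o + 1/2}{s} - \frac{1}{2} \quad \text{(output index } o\text{, upsampling factor } s\text{)}
\end{aligned}
```

Because each weight is a product of per-axis factors, the sum can be computed one axis at a time. This separability is what makes it possible to build trilinear upsampling out of 1-D or 2-D resize operations.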

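To upsample a 3D image with trilinear interpolation in TensorFlow, you can use tf.image.resize with method="bilinear", which, as far as I know, runs the tf.raw_ops.ResizeBilinear op under the hood. It interpolates bilinearly over the two spatial axes of a 4-D (batch, height, width, channels) tensor. Because trilinear interpolation is separable (one linear interpolation per axis), you can get it in two passes: first resize over (W, H), treating D as the channel axis, then resize along D.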

import tensorflow as tf

def trilinear_interpolation_3d(image, new_shape):
    """Resize a (W, H, D) tensor to new_shape = (W', H', D') with trilinear interpolation."""
    image = tf.convert_to_tensor(image, dtype=tf.float32)

    # Pass 1: bilinear resize over (W, H). Add a batch axis so the tensor is
    # 4-D (1, W, H, D); tf.image.resize treats D as the channel axis.
    resized = tf.image.resize(image[tf.newaxis], size=(new_shape[0], new_shape[1]),
                              method="bilinear")
    resized = resized[0]  # (W', H', D)

    # Pass 2: linear resize along D. W' acts as the batch axis and (H', D) as the
    # spatial axes with one channel; H' keeps its size, so only D is interpolated.
    resized = tf.image.resize(resized[..., tf.newaxis], size=(new_shape[1], new_shape[2]),
                              method="bilinear")

    # Drop the channel axis and return the (W', H', D') volume
    return resized[..., 0]

# Example usage
# Assuming x is your 3D image tensor of shape (W, H, D); a random one is used here
x = tf.random.uniform((4, 5, 6))

# Upsample x to twice its size in each dimension
upsampled_x = trilinear_interpolation_3d(x, tf.shape(x) * 2)
# upsampled_x.shape == (8, 10, 12)

I remember using this technique, but I am not sure if it will work for you.
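As a TensorFlow-free cross-check of the separable approach in the answer above, here is a minimal NumPy sketch that performs trilinear resizing as three successive 1-D linear passes. It assumes the half-pixel-centre convention with clamped borders, which I believe matches tf.image.resize's bilinear mode, though I have not verified exact agreement. The helper names `linear_resize_axis` and `trilinear_resize` are hypothetical.

```python
import numpy as np

def linear_resize_axis(a, new_len, axis):
    """1-D linear resize along `axis` (half-pixel centres, clamped at the borders)."""
    old_len = a.shape[axis]
    # Map each output index to a fractional input coordinate, then clamp it
    coords = (np.arange(new_len) + 0.5) * old_len / new_len - 0.5
    coords = np.clip(coords, 0, old_len - 1)
    lo = np.floor(coords).astype(int)
    hi = np.minimum(lo + 1, old_len - 1)
    w = coords - lo  # weight of the upper neighbour
    # Reshape the weights so they broadcast along `axis`
    shape = [1] * a.ndim
    shape[axis] = new_len
    w = w.reshape(shape)
    return np.take(a, lo, axis=axis) * (1 - w) + np.take(a, hi, axis=axis) * w

def trilinear_resize(volume, new_shape):
    """Trilinear resize of a (W, H, D) volume as one linear pass per axis."""
    out = np.asarray(volume, dtype=np.float64)
    for axis, n in enumerate(new_shape):
        out = linear_resize_axis(out, n, axis)
    return out

vol = np.random.rand(4, 5, 6)
up = trilinear_resize(vol, (8, 10, 12))
print(up.shape)  # (8, 10, 12)
```

Resizing to the original shape returns the input unchanged. A linear ramp along one axis comes back as a finer ramp, apart from the clamped border samples. Both are quick sanity checks to run against the TensorFlow version.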