I have trouble in distibuting the data across the gpus

When I try to distribute my data across the gpus and let them pass through the convolutional layers, it seems all things go pretty well.However, when it comes to the distribution to the second gpu, terminal gives me the feedback as follow, i don’t why the spatial dimension is 0 ?Actually, i just give it a satisfying format of input?
This is the error:

        in user code:
        
            File "MultiGPU-UnetStableDiffusion.py", line 202, in call  *
                self.unet_block = self._build_unet_block(dim_1, dim_2, self.width, self.height)
            File "MultiGPU-UnetStableDiffusion.py", line 106, in conv_block  *
                x = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
            File "/root/miniconda3/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler  **
                raise e.with_traceback(filtered_tb) from None
        
            ValueError: Exception encountered when calling layer "conv2d_36" (type Conv2D).
            
            `num_spatial_dims` must be 1, 2, or 3. Received: num_spatial_dims=0.
            
            Call arguments received by layer "conv2d_36" (type Conv2D):
              • inputs=tf.Tensor(shape=(None, 128, 128, 2752), dtype=float32)
        
        
        Call arguments received by layer "u_net_diffusion_module" (type UNetDiffusionModule):
          • noisy_images=tf.Tensor(shape=(4, 512), dtype=float32)
          • time_step=tf.Tensor(shape=(128,), dtype=float32)
          • text_embeddings=tf.Tensor(shape=(4, 33, 512), dtype=float32)
    
    
    Call arguments received by layer "text2_image_diffusion_model" (type Text2ImageDiffusionModel):
      • text_inputs=tf.Tensor(shape=(4, 33), dtype=int32)
      • image_inputs=tf.Tensor(shape=(4, 128, 128, 3), dtype=float32)
      • time_steps=tf.Tensor(shape=(128,), dtype=float32)
    and this is my model's structure :
class TextEncoder(tf.keras.Model):
    def __init__(self, vocab_size, output_dim=512, embed_dim=512):
        super(TextEncoder, self).__init__()
        self.embedding = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.text_projection = Dense(output_dim, activation='relu')


    def call(self, input_ids):
        outputs = self.embedding(input_ids)
        text_embeddings = self.text_projection(outputs)
        return text_embeddings


class ImageEncoder(tf.keras.Model):
    def __init__(self, input_shape=(WIDTH, HEIGHT, 3), output_dim=512):
        super(ImageEncoder, self).__init__()
        self.conv_blocks = [
            layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=input_shape),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(256, (4, 4), activation='relu', padding='same'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(512, (4, 4), activation='relu', padding='same'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(1024, (4, 4), activation='relu', padding='same'),
            layers.MaxPooling2D((2, 2)),
        ]
        self.flatten = layers.Flatten()
        self.image_projection = layers.Dense(output_dim)

    def call(self, inputs):
        x = inputs
        for layer in self.conv_blocks:
            x = layer(x)
        x = self.flatten(x)
        latent_representation = self.image_projection(x)
        return latent_representation




class UNetDiffusionModule(tf.keras.Model):
    def __init__(self, num_batch, width, height, time_embedding_dim=128, text_embedding_dim=64):
        super(UNetDiffusionModule, self).__init__()
        self.batch_size = num_batch 
        self.width = width
        self.height = height          
        self.time_embedding = Embedding(input_dim=time_embedding_dim, output_dim=self.batch_size)
        self.text_projection = Dense(units=text_embedding_dim)
        self.time_embedding_dim = time_embedding_dim
        self.text_embedding_dim = text_embedding_dim       

            
        self.final_conv = Conv2D(filters=3, kernel_size=(7, 7), strides=(1, 1), padding='same')
    
            
    def _build_unet_block(self, dim, dim2, width, height):
        inputs = Input(shape=(width, height, dim2 + dim * self.text_embedding_dim + self.time_embedding_dim))

        def conv_block(x, filters, kernel_size, strides=1, padding='same', activation='relu'):
            x = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
            x = BatchNormalization()(x)
            x = Activation(activation)(x)
            return x

        def deconv_block(x, filters, kernel_size, strides=2, padding='same', activation='relu'):
            x = Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)(x)
            x = BatchNormalization()(x)
            x = Activation(activation)(x)
            return x

        def residual_block(x, filters, kernel_size=3, strides=1, padding='same', activation='relu'):
            x1 = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
            x1 = BatchNormalization()(x1)
            x1 = Activation(activation)(x1)
            x1 = Conv2D(filters, kernel_size, strides=1, padding='same')(x1)
            x1 = BatchNormalization()(x1)
            x = Add()([x, x1])
            x = Activation(activation)(x)
            return x

        conv1 = conv_block(inputs, 64, 4, 2)  
        conv2 = conv_block(conv1, 128, 4, 2)  
        conv3 = conv_block(conv2, 256, 4, 2)  

            
        conv3 = residual_block(conv3, 256)
        conv3 = residual_block(conv3, 256)

        conv4 = conv_block(conv3, 512, 4, 2)  
        conv5 = conv_block(conv4, 512, 4, 2)  
        conv6 = conv_block(conv5, 512, 4, 2)  
        conv7 = conv_block(conv6, 512, 4, 2)  

        deconv8 = deconv_block(conv7, 512, 4, 2)  
        deconv8 = Concatenate()([deconv8, conv6])

    
        deconv8 = residual_block(deconv8, 1024)
        deconv8 = residual_block(deconv8, 1024)

        deconv9 = deconv_block(deconv8, 512, 4, 2)  
        deconv9 = Concatenate()([deconv9, conv5])

        
        deconv9 = residual_block(deconv9, 1024)
        deconv9 = residual_block(deconv9, 1024)

        deconv10 = deconv_block(deconv9, 512, 4, 2)  
        deconv10 = Concatenate()([deconv10, conv4])


        deconv10 = residual_block(deconv10, 1024)
        deconv10 = residual_block(deconv10, 1024)

        deconv11 = deconv_block(deconv10, 256, 4, 2) 
        deconv11 = Concatenate()([deconv11, conv3])

        
        deconv11 = residual_block(deconv11, 512)
        deconv11 = residual_block(deconv11, 512)

        deconv12 = deconv_block(deconv11, 128, 4, 2)  
        deconv12 = Concatenate()([deconv12, conv2])

        
        deconv12 = residual_block(deconv12, 256)
        deconv12 = residual_block(deconv12, 256)

        deconv13 = deconv_block(deconv12, 64, 4, 2)  
        deconv13 = Concatenate()([deconv13, conv1])

        
        deconv13 = residual_block(deconv13, 128)
        deconv13 = residual_block(deconv13, 128)

        deconv14 = deconv_block(deconv13, 64, 4, 2)

        outputs = Conv2D(3, kernel_size=7, strides=1, padding='same', activation='sigmoid')(deconv14)
        generator = Model(inputs=inputs, outputs=outputs)

        return generator

            
    def call(self, noisy_images, time_step, text_embeddings):
            
        time_embedding = self.time_embedding(time_step)
        text_embedding = self.text_projection(text_embeddings)
        dim_1 = tf.TensorShape(text_embedding.shape).as_list()[1]
        dim_2 = tf.TensorShape(noisy_images.shape).as_list()[-1]
        dim_3 = tf.TensorShape(text_embedding.shape).as_list()[0]

        self.unet_block = self._build_unet_block(dim_1, dim_2, self.width, self.height)
            
        time_embedding_reshaped = tf.reshape(time_embedding, [self.batch_size, 1, 1, self.time_embedding_dim])
            
        time_embedding_tiled = tf.tile(time_embedding_reshaped, [1, self.width, self.height, 1]) 

        noisy_images_reshaped = tf.reshape(noisy_images, [self.batch_size, 1, 1, dim_2])
        noisy_images_tiled = tf.tile(noisy_images_reshaped, [1, self.width, self.height, 1])            
            
            
        text_embedding_reshaped = tf.reshape(text_embedding, [self.batch_size, 1, 1, self.text_embedding_dim*dim_1])
        text_embedding_tiled = tf.tile(text_embedding_reshaped, [1, self.width, self.height, 1]) 
        
            
        d1 = Concatenate(axis=-1)([noisy_images_tiled, time_embedding_tiled, text_embedding_tiled]) 
            
        d1 = self.unet_block(d1)
        denoised_images = self.final_conv(d1)
        return denoised_images

    
class Text2ImageDiffusionModel(tf.keras.Model):
    def __init__(self, vocab_size, num_batch, width, height):
        super(Text2ImageDiffusionModel, self).__init__()
        self.text_encoder = TextEncoder(vocab_size)
        self.image_encoder = ImageEncoder()
        self.batch_size = num_batch
        self.width = width
        self.height = height
        self.diffusion_module = UNetDiffusionModule(self.batch_size, self.width, self.height)
        
    def call(self, text_inputs, image_inputs, time_steps):
        text_embeddings = self.text_encoder(text_inputs)
        latent_images = self.image_encoder(image_inputs)
        generated_images = self.diffusion_module(latent_images, time_steps, text_embeddings)
        return generated_images
        

and this is the calling function

def main():
    configuration()
    epochs = 10000
    implemented_coefficient = 0.37
    time_embedding_dim = 128
    csv_path = 'descriptions.csv'
    images_path = './images'

    strategy = tf.distribute.MirroredStrategy()
    print(f'Number of available GPUs: {strategy.num_replicas_in_sync}')


    def max_len(vectors) :
        length = max(len(vec) for vec in vectors)
        return length

    with strategy.scope():
        dataset, vocab_size, magnitude = load_dataset(csv_path, images_path, BATCH_SIZE, height, width)
        gross_magnitude = max_len(magnitude)

        text2image_model = Text2ImageDiffusionModel(vocab_size, BATCH_SIZE, width, height)
        optimizer = Adam(learning_rate=0.001)
        loss_fn = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

        text2image_model.compile(optimizer=optimizer, loss=loss_fn)

    log_file_path = './log/UnetSD.log'
    save_path = './samples'
    save_interval = 50

    def data_augmentation(images, noise_factor=0.1):
        noise = tf.random.normal(shape=(BATCH_SIZE, width, height, 3), mean=0.0, stddev=noise_factor)
        return images + noise

    @tf.function
    def train_step(batch):
        with tf.GradientTape() as tape:
            image_inputs, text_inputs = batch[0], batch[1]
            time_steps = tf.range(0, time_embedding_dim, dtype=tf.float32)
            noised_inputs = data_augmentation(image_inputs, implemented_coefficient)
            print(text_inputs.shape, noised_inputs.shape, time_steps.shape)
            generated_images = text2image_model(text_inputs, noised_inputs, time_steps)
            loss = loss_fn(image_inputs, generated_images)
            scaled_loss = tf.reduce_sum(loss) * (1. / (BATCH_SIZE * strategy.num_replicas_in_sync))
        gradients = tape.gradient(scaled_loss, text2image_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, text2image_model.trainable_variables))
        return scaled_loss

    for epoch in range(epochs):
        dataset = dataset.shuffle(buffer_size=vocab_size, reshuffle_each_iteration=True)
        iterator = iter(dataset)

        num, total_losses = 0, 0
        for num_, batch in enumerate(iterator):
            per_loss = strategy.run(train_step, args=(batch,))
            num += 1
            total_losses += per_loss
            print(f'per_batch_loss:{per_loss} epoch:{epoch} batch_index:{num_+1}')
        train_loss = total_losses / num

        print(f'Epoch {epoch + 1}/{epochs}, Loss: {train_loss.numpy()}')
        with open(log_file_path, 'a') as log_file:
            log_file.write(f"Epoch {epoch + 1}, Batch Losses: {train_loss.numpy()}\n")

i am really appreciate if you guys can help me fix that ! :grinning: :grinning: :grinning:

1 Like