Training startup when calling fit() extremely slow for model

Hi, a follow-up to my other threads:
I am benchmarking my new implementation of my recurrent autoencoder against my old implementation. For some reason, with the new implementation the initial startup before the first training epoch is very slow. Subsequent epochs are fine (and are faster than with the old implementation). With the old implementation the startup is normal and training starts almost instantly, with otherwise identical code.
Here is what I do:

import tensorflow as tf
import numpy as np
import time



class FRAEOld(tf.keras.Model):
    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, batch_size=1, bypass=True, trainable=True, **kwargs):
        super(FRAEOld, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        self.batch_size = batch_size
        # self.buffer = tf.Variable(initial_value=tf.zeros(shape=(batch_size,shape[0] * self.ht), dtype=tf.float32), trainable=False)

        self.SetupBuffer(batch_size, shape[0], ht)
        self.bypass = bypass
        self.quantizer = None
        self.trainable = trainable

        self.l1 = tf.keras.layers.Dense(n1, activation='tanh', input_shape=shape)
        self.l2 = tf.keras.layers.Dense(n2, activation='tanh')  # has to be fixed to n2
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')

        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')

        self.save_encoded = False
        self.encoded_list = []

    def SetupBuffer(self, batch_size, input_dim, ht):
        self.buffer = tf.Variable(initial_value=tf.zeros(shape=(batch_size, input_dim * ht), dtype=tf.float32),
                                  trainable=False)
        self.batch_size = batch_size

    def get_config(self):
        config = super(FRAEOld, self).get_config().copy()
        config.update({'latent_dim': self.latent_dim, 'bypass': self.bypass, 'quantizer': self.quantizer,
                       "l1": self.l1, "l2": self.l2, "ls": self.ls, "l3": self.l3, "l4": self.l4, "l5": self.l5,
                       "ht": self.ht, "batch_size": self.batch_size, "shape": self.shape, "name": self.name})

        return config

    # @tf.function(experimental_compile=True)
    def resetBuffer(self):
        self.buffer[:, :].assign(tf.zeros(shape=(1, self.shape[0] * self.ht), dtype=tf.float32))

    def quantize(self, x):
        x_np = x.numpy().astype(np.float32)
        return self.quantizer.cluster_centers_[self.quantizer.predict(x_np), :]


    def call(self, x):
        if self.bypass:
            return x
        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[1])
        for i in tf.range(tf.shape(x)[1]):
            xslice = x[:, i, :]
            xin = tf.concat((xslice, self.buffer), axis=1)
            encoded = self.ls(self.l2(self.l1(xin)))

            if self.save_encoded:
                self.encoded_list.append(encoded.numpy())

            if self.quantizer is not None:
                encoded = tf.py_function(self.quantize, [encoded], tf.float32)

            decin = tf.concat([encoded, self.buffer], axis=1)
            y = self.l5(self.l4(self.l3(decin)))
            decoded = decoded.write(i, y)
            self.buffer.assign(tf.concat([y, self.buffer[:, :-self.shape[0]]], axis=1))

        tmp = tf.transpose(decoded.stack(), [1, 0, 2])
        return tmp


class VectorQuantization(tf.keras.layers.Layer):
    def __init__(self, codebook = None, record_data = False, **kwargs):
        super(VectorQuantization, self).__init__(**kwargs)
        self.codebook = codebook
        self.record_data = record_data

    def SetCodebook(self, codebook):
        self.codebook = codebook

 

    @tf.custom_gradient
    def call(self, inputs):
        def grad(dy):
            # straight-through estimator: pass the incoming gradient through unchanged
            return dy

        if self.codebook is not None:
            # Flatten input tensor
            input_shape = tf.shape(inputs)
            flat_inputs = tf.reshape(inputs, [-1, input_shape[-1]])
            # Reshape codebook for distance calculation
            reshaped_codebook = np.expand_dims(self.codebook, axis=0)
            # Calculate distances between inputs and codebook vectors
            distances = tf.reduce_sum(tf.square(tf.expand_dims(flat_inputs, axis=1) - reshaped_codebook), axis=2)

            # Find the index of the closest centroid for each input
            embedding_indices = tf.argmin(distances, axis=1)

            # Gather closest embeddings from codebook
            quantized = tf.gather(tf.convert_to_tensor(self.codebook), embedding_indices)

            # Reshape quantized tensor to match input shape
            quantized = tf.reshape(quantized, input_shape)


            return quantized, grad
        else:
            return inputs, grad


class FRAEEncoder(tf.keras.layers.Layer):
    def __init__(self, input_shape, latent_dim,   layer_config, **kwargs):
        super(FRAEEncoder, self).__init__(**kwargs)
        self.SetupLayers(input_shape, latent_dim, layer_config)

    def SetupLayers(self, input_shape, latent_dim, layer_config):
        self.encoder = []
        activations  = layer_config["activations"]
        num_neuron   = layer_config["neurons"]
        for i, act in enumerate(activations):
            if i == 0:
                self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=act, input_shape=(input_shape,)))
            else:
                self.encoder.append(tf.keras.layers.Dense(num_neuron[i], activation=act))

        if len(activations) > 0:
            self.encoder.append(tf.keras.layers.Dense(latent_dim, "swish"))
        else:
            self.encoder.append(tf.keras.layers.Dense(latent_dim, "swish", input_shape=(input_shape,)))



    def call(self, inputs, state):
        # encoder
        encoded = tf.concat([inputs, state], axis=-1)

        for lrs in self.encoder:
            encoded = lrs(encoded)

        new_output = encoded
        return new_output

class FRAEDecoder(tf.keras.layers.Layer):
    def __init__(self,   output_dim, layer_config, **kwargs):
        super(FRAEDecoder, self).__init__(**kwargs)
        self.SetupLayers(output_dim, layer_config)

    def SetupLayers(self,  output_dim, layer_config):
        self.decoder = []
        activations  = layer_config["activations"]
        num_neuron   = layer_config["neurons"]

        for i, act in reversed(list(enumerate(activations))):
            self.decoder.append(tf.keras.layers.Dense(num_neuron[i],activation=act))

        self.decoder.append(tf.keras.layers.Dense(output_dim, activation='linear'))


    def call(self, encoded, state):
        y = tf.concat([encoded, state], axis=1)
        for lrs in self.decoder:
            y = lrs(y)

        #update output and state
        new_output = y  # Result of some operations on `combined`
        new_state  = tf.concat([new_output, state[:, :-tf.shape(new_output)[-1]]], axis=-1)
        return new_output, new_state




class FRAE(tf.keras.Model):
    def __init__(self,  output_dim, latent_dim, ht,  layer_config={"activations":[],"neurons":[]} , **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.ht = ht
        self.encoder = FRAEEncoder(output_dim, latent_dim, layer_config, name=self.name+"_Encoder")
        self.decoder = FRAEDecoder(output_dim, layer_config, name=self.name+"_Decoder")
        self.quantizer = VectorQuantization(name=self.name+"_VQ")


    def SetQuantizer(self, codebook):
        self.quantizer.SetCodebook(codebook)

    def call(self, x, return_quantized=False):
        state = tf.zeros(shape=(tf.shape(x)[0], tf.shape(x)[2] * self.ht))

        x_unstacked = tf.unstack(x, axis=1)

        output_list = [None] * len(x_unstacked)
        for i, inputs in enumerate(x_unstacked):
            encoded = self.encoder(inputs, state)
            encoded_q = self.quantizer(encoded)
            decoded, state = self.decoder(encoded_q, state)
            output_list[i] = decoded

        outputs = tf.stack(output_list, axis=1)
        if return_quantized:
            return outputs, encoded_q
        else:
            return outputs


if __name__ == '__main__':
    output_dim = 8
    latent_dim = 6


    ht = 1
    with tf.device('/cpu'):
        initdata = np.random.rand(256,1,output_dim)
        data = np.random.rand(256*10, 3000, output_dim)
        layer_config = {'activations':['tanh','tanh'], 'neurons':[20,20]}
        frae = FRAE(output_dim, latent_dim, ht, layer_config)
        fraeold = FRAEOld(latent_dim, (output_dim,),ht,  20,20,20,20,256,False)
        frae(initdata)
        fraeold(initdata)
        frae.summary()
        fraeold.summary()
        loss = tf.keras.losses.mse

        frae.compile(optimizer='adam', loss=loss, run_eagerly=False)
        frae.fit(data, data,  epochs=10, batch_size=256, verbose=1)

    print("Done")

This takes ages to start the first epoch (about 15 minutes). If I replace “frae.fit(…)” with “fraeold.fit(…)”, training starts virtually instantly, so the data size cannot be the sole reason for the slowdown.
Does anyone know what the reason is and how to solve it? It is really annoying, although it is not too bad when I train for many epochs.

For comparison (and more data than posted above):
With my old model, the first epoch takes 42 seconds, the other epochs about 38 seconds.
With my new model, the first epoch takes 2800 seconds (sic!), the other epochs about 31 seconds.

Does anyone have an idea?

The issue you’re experiencing with the slow startup time before the first epoch of your new recurrent autoencoder implementation could be due to several reasons. Let’s explore some potential causes and solutions:

  1. Model Initialization and Compilation Overhead: The new model structure may introduce a significant overhead during initialization and compilation. TensorFlow has to construct a computational graph for your model, which can be time-consuming, especially for complex models with custom layers and operations.
  2. Data Pipeline: The way data is fed into the model might be causing the delay. Ensure that the data pipeline is optimized and not the bottleneck. TensorFlow’s tf.data API provides various tools for efficient data loading and preprocessing; see the sketch right after this list.
  3. Custom Layers and Operations: Your new model implementation uses custom layers and operations, such as VectorQuantization and dynamic operations within call methods. These custom components can add overhead, especially if they involve operations that are not fully optimized or are executed eagerly.
  4. Eager Execution vs. Graph Execution: Ensure that TensorFlow is running in graph mode (which is the default in TensorFlow 2.x when using fit, unless run_eagerly=True is set). Eager execution is beneficial for debugging but can lead to slower execution times compared to graph execution.
  5. Device Placement: You’re explicitly placing the model on the CPU with tf.device('/cpu'). Depending on your system’s configuration, initializing the model on a GPU (if available) might be faster due to parallelism, even though subsequent epochs are not heavily affected.
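
As a rough illustration of point 2, here is a minimal sketch of such a tf.data pipeline with caching and prefetching. The dataset construction and batch size are assumptions based on the fit() call in your script, not code taken from your post:

import tensorflow as tf

# Minimal sketch of an input pipeline for the autoencoder (input == target).
# `data` is assumed to be the (samples, timesteps, features) array passed to fit().
dataset = (
    tf.data.Dataset.from_tensor_slices((data, data))
    .cache()                        # keep elements in memory after the first pass
    .shuffle(buffer_size=1024)
    .batch(256)                     # same batch size as in the fit() call
    .prefetch(tf.data.AUTOTUNE)     # overlap data preparation with training
)

# frae.fit(dataset, epochs=10, verbose=1)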

Solutions:

  • Model Simplification: Review your model’s architecture to see if there are any redundancies or overly complex operations that could be simplified without sacrificing the model’s performance.
  • Optimize Data Pipeline: Ensure that your data pipeline is efficient. Utilize tf.data API features such as prefetch, cache, and interleave to optimize data loading and preprocessing.
  • Graph Execution: Make sure the model is running in graph mode during training. Remove or minimize any operations that force TensorFlow to revert to eager execution. This includes removing unnecessary tf.py_function calls and ensuring that custom layers are graph-compatible.
  • Batching Operations: Whenever possible, batch operations outside of loops. For example, your model processes data in a loop over tf.range(tf.shape(x)[1]). If it’s feasible, try to vectorize these operations to process multiple time steps at once.
  • Profiling: Use TensorFlow’s profiling tools, such as TensorBoard’s Profiler, to identify bottlenecks in model initialization and the first epoch. This can give you insights into where the delays are occurring; see the profiling sketch after this list.
  • Incremental Testing: Simplify your model to the bare minimum and gradually add complexity, testing the startup time at each step. This can help pinpoint the exact cause of the slowdown.
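
To capture the startup itself, here is a minimal sketch using the TensorBoard callback’s profile_batch argument so the trace covers the very first batches, which is where one-time startup work such as graph tracing happens (the log directory is an assumption):

import tensorflow as tf

# Minimal sketch: profile batches 1-5 of the first epoch, which includes
# graph tracing and other one-time startup work.
tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/startup",   # assumed path
    profile_batch=(1, 5),
)

# frae.fit(data, data, epochs=1, batch_size=256, callbacks=[tb_callback])
# Inspect the trace with: tensorboard --logdir logs/startup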

By addressing these areas, you should be able to reduce the initialization time of your new model implementation.

The model is simple. I just tried it with a data generator and prefetch; nothing changed.
I am executing in graph mode, and I cannot avoid the loop in the recurrent part of my model.
I am trying to profile my model using TensorBoard now. However, so far I have only found examples showing how to debug the training evolution, not the initial model startup. Do you know how to do that?

OK, I got something using TensorBoard:

This looks rather strange. I guess the for loop yields a super long graph which takes forever? But this was not the case with the previous approach, likely due to not calling another layer in the for loop. Is there some simple way to solve this?

edit:
Some additional images: Close up of the recurrent model

Why exactly are there a few tensors bypassing the VQ??? Is this just during training? This is a close up of the dense layer right before the VQ layer

Alright, I think I understand a little bit more now and was able to solve it!
The recurrent part is now:

    @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 8])])
    def call(self, input_data):
        if self.bypass:
            return input_data
        else:
            state = tf.zeros(shape=(tf.shape(input_data)[0], tf.shape(input_data)[2] * self.ht))

            # [batch, time, features] -> [time, batch, features]
            input_data = tf.transpose(input_data, [1, 0, 2])
            timesteps = tf.shape(input_data)[0]
            batch_size = tf.shape(input_data)[1]
            outputs = tf.TensorArray(tf.float32, timesteps)
            for i in tf.range(timesteps):
                encoded = self.encoder(input_data[i], state)
                encoded_q = self.quantizer(encoded)
                decoded, state = self.decoder(encoded_q, state)

                outputs = outputs.write(i, decoded)
            return tf.transpose(outputs.stack(), [1, 0, 2])

Previously, the for loop resulted in a very long graph, which should be prevented here by the use of a tf.TensorArray to store the outputs (instead of a Python list). This starts quickly and is sufficiently fast during training. Thank you a lot.
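
For anyone hitting the same issue: the difference seems to be how the time loop gets traced. Iterating over a Python list (e.g. the result of tf.unstack) unrolls the loop into one subgraph per timestep, whereas a loop over tf.range writing into a tf.TensorArray is traced once as a single tf.while_loop. A minimal sketch of the two patterns (the step function is just a stand-in, not my model):

import tensorflow as tf

def step(x_t, state):
    # stand-in for the per-timestep encoder/quantizer/decoder work
    return x_t + state, state + 1.0

@tf.function
def unrolled(x):
    # Python loop over unstacked timesteps: traced into one subgraph per step,
    # so tracing time grows with the sequence length.
    state = tf.zeros_like(x[:, 0, :])
    outputs = []
    for x_t in tf.unstack(x, axis=1):
        y, state = step(x_t, state)
        outputs.append(y)
    return tf.stack(outputs, axis=1)

@tf.function
def rolled(x):
    # Loop over tf.range with a TensorArray: traced once as a tf.while_loop,
    # so tracing time is independent of the sequence length.
    state = tf.zeros_like(x[:, 0, :])
    timesteps = tf.shape(x)[1]
    outputs = tf.TensorArray(tf.float32, size=timesteps)
    for i in tf.range(timesteps):
        y, state = step(x[:, i, :], state)
        outputs = outputs.write(i, y)
    return tf.transpose(outputs.stack(), [1, 0, 2])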

Was this comment AI generated?
