Code is stuck at epoch 1 in lazy mode

my code below is stuck at epoch 1 and wont run in lazy mode due to the loop
“for i, routers_indice in enumerate(routers_indices)” without that loop it runs fine.
its taking too long to run in eager mode


class Baseline_cbr_mb(tf.keras.Model):
    mean_std_scores_fields = {
        "flow_traffic",
        "flow_packets",
        "flow_pkts_per_burst",
        "flow_bitrate_per_burst",
        "flow_packet_size",
        "flow_p90PktSize",
        "rate",
        "flow_ipg_mean",
        "ibg",
        "flow_ipg_var",
        "link_capacity",
    }
    mean_std_scores = None

    name = "ufpa_ericsson_cbr_mb"

    def __init__(self, override_mean_std_scores=None, name=None):
        super(Baseline_cbr_mb, self).__init__()

        self.iterations = 12
        self.path_state_dim = 16
        self.link_state_dim = 16

        if override_mean_std_scores is not None:
            self.set_mean_std_scores(override_mean_std_scores)
        if name is not None:
            assert type(name) == str, "name must be a string"
            self.name = name

        self.attention = tf.keras.Sequential(
            [tf.keras.layers.Input(shape=(None, None, self.path_state_dim)),
            tf.keras.layers.Dense(
                self.path_state_dim, activation=tf.keras.layers.LeakyReLU(alpha=0.01)    
            ),
            ]
        )

            # GRU Cells used in the Message Passing step
        self.path_update = tf.keras.layers.RNN(
            tf.keras.layers.GRUCell(self.path_state_dim, name="PathUpdate",
            ),
            return_sequences=True,
            return_state=True,
            name="PathUpdateRNN",
        )
        self.link_update = tf.keras.layers.GRUCell(
            self.link_state_dim, name="LinkUpdate",
        )

        self.flow_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=13),
                tf.keras.layers.Dense(
                    self.path_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(
                    self.path_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    )
            ],
            name="PathEmbedding",
        )
        self.switch_link_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=1),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                kernel_initializer='lecun_uniform',
                ),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                kernel_initializer='lecun_uniform',
                ),
            ],
            name="SwitchLinkEmbedding",
        )


        self.router_link_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=1),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                kernel_initializer='lecun_uniform',
                ),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                kernel_initializer='lecun_uniform',
                ),
            ],
            name="RouterLinkEmbedding",
        )
        self.link_embedding = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=3),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(
                    self.link_state_dim, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    )
            ],
            name="LinkEmbedding",
        )

        self.readout_path = tf.keras.Sequential(
            [
                tf.keras.layers.Input(shape=(None, self.path_state_dim)),
                tf.keras.layers.Dense(
                    self.link_state_dim // 2, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(
                    self.link_state_dim // 4, activation=tf.keras.activations.selu,
                    kernel_initializer='lecun_uniform',
                    ),
                tf.keras.layers.Dense(1, activation=tf.keras.activations.softplus)
            ],
            name="PathReadout",
        )
    
    def set_mean_std_scores(self, override_mean_std_scores):
        assert (
            type(override_mean_std_scores) == dict
            and all(kk in override_mean_std_scores for kk in self.mean_std_scores_fields)
            and all(len(val) == 2 for val in override_mean_std_scores.values())
        ), "overriden mean-std dict is not valid!"
        self.mean_std_scores = override_mean_std_scores

    @tf.function
    def call(self, inputs):
        # Ensure that the min-max scores are set
        assert self.mean_std_scores is not None, "the model cannot be called before setting the min-max scores!"

        # Process raw inputs
        devices = inputs["devices"]

        flow_traffic = inputs["flow_traffic"]
        flow_packets = inputs["flow_packets"]
        global_delay = inputs["global_delay"]
        global_losses = inputs["global_losses"]
        max_link_load = inputs["max_link_load"]
        flow_pkt_per_burst = inputs["flow_pkts_per_burst"]
        flow_bitrate = inputs["flow_bitrate_per_burst"]
        flow_packet_size = inputs["flow_packet_size"]
        flow_type = inputs["flow_type"]
        flow_ipg_mean = inputs["flow_ipg_mean"]
        flow_length = inputs["flow_length"]
        ibg = inputs["ibg"]
        flow_p90pktsize = inputs["flow_p90PktSize"]
        cbr_rate = inputs["rate"]
        flow_ipg_var = inputs["flow_ipg_var"]
        link_capacity = inputs["link_capacity"]
        link_to_path = inputs["link_to_path"]
        path_to_link = inputs["path_to_link"]

        flow_pkt_size_normal = (flow_packet_size - self.mean_std_scores["flow_packet_size"][0]) \
                    * self.mean_std_scores["flow_packet_size"][1],

        path_gather_traffic = tf.gather(flow_traffic, path_to_link[:, :, 0])
        load = tf.math.reduce_sum(path_gather_traffic, axis=1) / (link_capacity * 1e9)
        normal_load = tf.math.divide(load, tf.squeeze(max_link_load))
        # normal_load = tf.math.divide(load, max_link_load)

        # Initialize the initial hidden state for paths
        path_state = self.flow_embedding(
            tf.concat(
                [
                    (flow_traffic - self.mean_std_scores["flow_traffic"][0])
                    * self.mean_std_scores["flow_traffic"][1],
                    (flow_packets - self.mean_std_scores["flow_packets"][0])
                    * self.mean_std_scores["flow_packets"][1],
                    (ibg - self.mean_std_scores["ibg"][0])
                    * self.mean_std_scores["ibg"][1],
                    (cbr_rate - self.mean_std_scores["rate"][0])
                    * self.mean_std_scores["rate"][1],
                    (flow_p90pktsize - self.mean_std_scores["flow_p90PktSize"][0])
                    * self.mean_std_scores["flow_p90PktSize"][1],
                    (flow_packet_size - self.mean_std_scores["flow_packet_size"][0])
                    * self.mean_std_scores["flow_packet_size"][1],
                    (flow_bitrate - self.mean_std_scores["flow_bitrate_per_burst"][0])
                    * self.mean_std_scores["flow_bitrate_per_burst"][1],
                    (flow_ipg_mean - self.mean_std_scores["flow_ipg_mean"][0])
                    * self.mean_std_scores["flow_ipg_mean"][1],
                    (flow_ipg_var - self.mean_std_scores["flow_ipg_var"][0])
                    * self.mean_std_scores["flow_ipg_var"][1],
                    (flow_pkt_per_burst - self.mean_std_scores["flow_pkts_per_burst"][0])
                    * self.mean_std_scores["flow_pkts_per_burst"][1],
                    tf.expand_dims(tf.cast(flow_length, dtype=tf.float32), 1),
                    flow_type
                ],
                axis=1,
            )
        )

        # Initialize the initial hidden state for links
        link_state = self.link_embedding(
            tf.concat(
                [
                   (link_capacity - self.mean_std_scores["link_capacity"][0])
                    * self.mean_std_scores["link_capacity"][1],
                    load,
                    normal_load,
                ],
                axis=1,
            ),
        )

      
        all_devices_used = tf.gather(link_capacity, link_to_path, name="LinkToPath")
        routers_used = tf.ragged.boolean_mask(all_devices_used, tf.cast(devices, tf.bool))
        switches_used = tf.ragged.boolean_mask(all_devices_used, ~tf.cast(devices, tf.bool))
        routers_indices = tf.ragged.boolean_mask(link_to_path, tf.cast(devices, tf.bool))
        switch_indices = tf.ragged.boolean_mask(link_to_path, ~tf.cast(devices, tf.bool))

    
        routers_embeddings  = self.router_link_embedding(routers_used)
        switch_embeddings  = self.switch_link_embedding(switches_used) 
  
        link_state_variable = tf.Variable(initial_value=link_state)

        for i, routers_indice in enumerate(routers_indices):
            link_state_variable = link_state_variable.scatter_nd_update(tf.expand_dims(routers_indice, 1), routers_embeddings[i])
            link_state_variable = link_state_variable.scatter_nd_update(tf.expand_dims(switch_indices[i], 1), switch_embeddings[i])

        link_state = link_state_variable.read_value()

        # for i, routers_indice in enumerate(routers_indices):
        #     link_state = tf.tensor_scatter_nd_update(link_state, tf.expand_dims(routers_indice, 1), routers_embeddings[i])
        #     link_state = tf.tensor_scatter_nd_update(link_state, tf.expand_dims(switch_indices[i], 1), switch_embeddings[i])
        
        # # for i in range()

        # Iterate t times doing the message passing
        for _ in range(self.iterations):
            ####################
            #  LINKS TO PATH   #
            ####################
            
            link_gather = tf.gather(link_state, link_to_path, name="LinkToPath")

            previous_path_state = path_state
            path_state_sequence, path_state = self.path_update(
                link_gather, initial_state=path_state
            )
            
            # We select the element in path_state_sequence so that it corresponds to the state before the link was considered
            path_state_sequence = tf.concat(
                [tf.expand_dims(previous_path_state, 1), path_state_sequence], axis=1
            )
            
            ###################
            #   PATH TO LINK  #
            ###################
            path_gather = tf.gather_nd(
                path_state_sequence, path_to_link, name="PathToLink"
            )
            
            attention_coef = self.attention(path_gather)
            normalized_score = K.softmax(attention_coef)
            weighted_score = normalized_score * path_gather
            
            path_gather_score = tf.math.reduce_sum(weighted_score, axis=1)
            
            link_state, _ = self.link_update(path_gather_score, states=link_state)

        ################
        #  READOUT     #
        ################

        occupancy = self.readout_path(path_state_sequence[:, 1:])

        capacity_gather = tf.gather(link_capacity, link_to_path)
        
        queue_delay = occupancy / capacity_gather
        queue_delay = tf.math.reduce_sum(queue_delay, axis=1)

        return queue_delay

@blessedg
It looks like the issue might be related to the use of tf.Variable and the loop in the following part of your code:

link_state_variable = tf.Variable(initial_value=link_state)

for i, routers_indice in enumerate(routers_indices):
    link_state_variable = link_state_variable.scatter_nd_update(tf.expand_dims(routers_indice, 1), routers_embeddings[i])
    link_state_variable = link_state_variable.scatter_nd_update(tf.expand_dims(switch_indices[i], 1), switch_embeddings[i])

link_state = link_state_variable.read_value()

Using a loop with tf.Variable and scatter_nd_update operations can be problematic, especially when using TensorFlow in eager mode. It might be causing issues with the automatic differentiation and graph building, leading to slow execution.

Instead of using a loop for updating the link_state_variable, you can try using vectorized operations. Try doing this to avoid the loop:

link_indices = tf.concat([routers_indices, switch_indices], axis=1)
link_embeddings = tf.concat([routers_embeddings, switch_embeddings], axis=0)

link_state_variable = tf.tensor_scatter_nd_add(link_state_variable, link_indices, link_embeddings)

Replace the loop with this block of code and see if it resolves the issue. If not, let me know and we can sort this out together.

It returns this error

link_gather_z = tf.tensor_scatter_nd_update(link_gather_z, routers_indices, router)

ValueError: Exception encountered when calling layer ‘ufpa_ericsson_cbr_mb’ (type Baseline_cbr_mb).

TypeError: object of type ‘RaggedTensor’ has no len()

This is because routers_indices and switch indices are ragged tensors.

However, I have resolved it using

    i = tf.constant(0)
    link_state_variable = link_state  # Assuming link_state is defined

    # Condition function for the while loop
    def condition(i, link_state_variable):
        return i < tf.shape(routers_indices)[0]

    # Body function for the while loop
    def body(i, link_state_variable):
        routers_indice = tf.expand_dims(routers_indices[i], 1)
        switch_indice = tf.expand_dims(switch_indices[i], 1)

        # Update link_state using tensor_scatter_nd_update
        link_state_variable = tf.tensor_scatter_nd_update(link_state_variable, routers_indice, routers_embeddings[i])
        link_state_variable = tf.tensor_scatter_nd_update(link_state_variable, switch_indice, switch_embeddings[i])

        return i + 1, link_state_variable

    # Run the while loop
    final_i, link_state = tf.while_loop(
        condition, body, loop_vars=[i, link_state_variable], parallel_iterations=1
    )

@BadarJaffer

The problem i am currently facing is i have a ragged tensor of shape link_gather = (72, None, 16) each row is of varying length e.g row1 could be [7 x16] and row 2 could be [5 x 16].
I have a router_embedding (1x16) and switch_embedding (1x16). I also have ragged tensor device of shape 72xNone. it contains 0’s and 1s showing wether each item in link_gather is a router or switch.
How do I expand each row of link_gather with a router_embedding or switch embedding based on the device tensor. So that, for example the first row now becomes 14x 16 instead of 7x16 by interleaving the switch and row embeddings based on the device ragged tensor

@blessedg you can use the tf.ragged.map_flat_values function. This function applies a function to each scalar value in a ragged tensor while preserving the ragged structure. Here’s how you can achieve the desired result:

import tensorflow as tf

# Assuming you have the following inputs
link_gather = ...  # Shape (72, None, 16)
router_embedding = ...  # Shape (1, 16)
switch_embedding = ...  # Shape (1, 16)
device = ...  # Shape (72, None)

# Convert the router_embedding and switch_embedding to ragged tensors
router_embedding_ragged = tf.RaggedTensor.from_row_splits(router_embedding, [0, 1])
switch_embedding_ragged = tf.RaggedTensor.from_row_splits(switch_embedding, [0, 1])

# Create a function to expand each row of link_gather
def expand_row(row, is_router):
    # Choose the appropriate embedding based on is_router
    embedding = router_embedding_ragged if is_router else switch_embedding_ragged
    # Repeat the embedding along the row
    repeated_embedding = tf.repeat(embedding, row.row_lengths(), axis=0)
    return repeated_embedding

# Use tf.ragged.map_flat_values to apply the function to each row
expanded_link_gather = tf.ragged.map_flat_values(expand_row, link_gather, device)

# Print the result
print(expanded_link_gather)

Make sure to adapt the code according to the actual shapes and values of your inputs.

Thank you so much for the help @BadarJaffer

1 Like