Where does the tape break after all?

**Scenario**

Let us consider the following hypothetical scenario:

  1. A dataset which yields (images, labels) pairs.
  2. A model which outputs a tensor of logits with shape (3, batch_size, C), where C is the number of classes.
  3. Three losses are computed, one for each slice (i, batch_size, C), with i = 0, 1, 2.

Let us break the code down into pieces instead of looking at one long block.

The following creates a simple model.

import tensorflow as tf


# Let's create a model first
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        # Three parallel heads; each maps the input to 10-class logits.
        self._l1 = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=3
                ),
                tf.keras.layers.GlobalAveragePooling2D(keepdims=False),
                tf.keras.layers.Dense(units=10)
            ]
        )
        self._l2 = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=3
                ),
                tf.keras.layers.GlobalAveragePooling2D(keepdims=False),
                tf.keras.layers.Dense(units=10)
            ]
        )
        self._l3 = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=3
                ),
                tf.keras.layers.GlobalAveragePooling2D(keepdims=False),
                tf.keras.layers.Dense(units=10)
            ]
        )

    def call(self, inputs, training=None, mask=None):
        y1 = self._l1(inputs)
        y2 = self._l2(inputs)
        y3 = self._l3(inputs)
        out = tf.stack([y1, y2, y3])
        # Output shape is (3, batch_size, 10)

        return out
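
As a quick sanity check on the shape claimed in the comment above, a throwaway instance can be run on a dummy batch (`_m` is just an illustrative name, not part of the training code):

_m = MyModel()  # throwaway instance, only for this shape check
print(_m(tf.zeros((5, 224, 224, 3))).shape)  # (3, 5, 10)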

Then we create a fake random dataset for image classification with 10 classes.



model = MyModel()


# just a random image classification dataset
def get_dataset():
    db = tf.data.Dataset.range(100)
    db = db.map(
        lambda x: (tf.random.uniform(shape=(224, 224, 3), dtype=tf.float32),
                   tf.one_hot(tf.random.uniform(shape=(), minval=0, maxval=10, dtype=tf.int64), depth=10)
                   )
    )
    db = db.batch(5)
    return db

dataset = get_dataset()
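
Just to show what one batch of this fake dataset looks like (a sanity check only, not needed for the question):

images, labels = next(iter(dataset))
print(images.shape)  # (5, 224, 224, 3)
print(labels.shape)  # (5, 10)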

Below is the function we use to compute the individual losses. Note that it returns a Python list.

# Computes the three individual losses and returns them as a Python list.
# Appending to a list is used to compute the loss of the outputs of l1, l2 and l3 (see the model) with respect to the GT labels.
# Relying on a Python list is usually discouraged; tf.TensorArray is recommended instead (a sketch follows after this function).

# Check out https://github.com/tensorflow/tensorflow/issues/37512
def get_loss_v1(logits, labels):
    losses = list()
    for i in range(3):
        y_pred = tf.gather(logits, i)
        losses.append(
            tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(
                    from_logits=True,
                    y_pred=y_pred,
                    y_true=labels
                )
            )
        )
    # losses = tf.stack(losses) # Uncommenting it will give gradients as None
    return losses
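
For reference, here is roughly what the tf.TensorArray variant hinted at in the comment above could look like. This is only a sketch (`get_loss_v1_ta` is an illustrative name); I have not used it in the runs discussed below:

# Illustrative TensorArray-based sketch of get_loss_v1; not used in the runs below.
def get_loss_v1_ta(logits, labels):
    losses = tf.TensorArray(dtype=tf.float32, size=3)
    for i in range(3):
        y_pred = tf.gather(logits, i)
        loss_i = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(
                from_logits=True,
                y_pred=y_pred,
                y_true=labels
            )
        )
        losses = losses.write(i, loss_i)
    return losses.stack()  # a (3,) tensor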

Below is a simple training step and the main loop. With get_loss_v1() as written above, the gradients are computed correctly.

def train_step(images, labels):
    with tf.GradientTape(persistent=True) as tape:
        logits = model(images, training=True)
        loss_val = get_loss_v1(logits, labels)
        print(loss_val)  # This prints the forward losses properly
    grads = list()
    for i in range(3):
        grads.append(
            tape.gradient(
                target=loss_val[i],
                sources=model.trainable_variables
            )
        )
    print(grads)  # With get_loss_v1() as above, this prints proper gradients (not None)


for data in dataset:
    images, labels = data
    train_step(images, labels)
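
For completeness, in a real run I would apply these per-loss gradient lists with an optimizer, roughly as in the sketch below (Adam and the helper apply_grads are just placeholders, not part of the question):

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)


def apply_grads(per_loss_grads):
    # per_loss_grads is the list of three gradient lists built in train_step;
    # None entries (variables a given loss does not touch) are skipped.
    for grads in per_loss_grads:
        pairs = [
            (g, v) for g, v in zip(grads, model.trainable_variables)
            if g is not None
        ]
        optimizer.apply_gradients(pairs)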

**My Questions**

Let us consider a set of changes that break the code; walking through them will clarify my questions.

In computing tape.gradient(), if we replace loss_val[i] with tf.gather(loss_val, i), the code breaks. I think I understand why: since the replacement happens outside the with tf.GradientTape() block, it is akin to introducing a new transformation that the tape has not watched. Am I correct about this?
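
To make concrete what I mean by a transformation the tape has not watched, here is a tiny standalone sketch (nothing to do with the model above; x and ys are just throwaway names):

x = tf.Variable(3.0)
with tf.GradientTape(persistent=True) as tape:
    ys = [x * 1.0, x * 2.0]       # these ops are recorded on the tape
y = tf.stack(ys)                  # this stack happens outside the tape, so it is not recorded
print(tape.gradient(ys[0], x))    # tf.Tensor(1.0, ...): the op was recorded, the gradient flows
print(tape.gradient(y[0], x))     # None: y was produced outside the tape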

At the bottom of get_loss_v1(...), uncomment the line losses = tf.stack(losses). This change lies well within the watch of the tape, and yet the gradients come out as None. I do not understand why this happens.

As another small observation, let us bring in an alternative way of computing the same losses that get_loss_v1() gave us.

# Computes the loss. But gradient backpropagation fails.
# Here tf.map_fn is used to compute the loss of the outputs of l1, l2 and l3 (see the model) with respect to the GT labels.
def get_loss_v2(logits, labels):
    loss = tf.map_fn(
        fn=lambda x: tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(
                from_logits=True,
                y_pred=logits[x, ...],
                y_true=labels
            )
        ),
        elems=tf.range(3),
        fn_output_signature=tf.float32
    )
    # Note: tf.map_fn already returns a stacked (3,) tensor, so no extra tf.stack is needed here.
    return loss

Then we simply do loss_val = get_loss_v2(logits, labels). This should work, because everything is watched by the tape. Yet the gradients are None. Why?

In fact, get_loss_v2() seems like the way to go, since the loop can be parallelized and we stay entirely within the TensorFlow domain without using any Python constructs.

So, why does the tape break after all?