Distribute Strategy with Keras Custom Loops

I posted this question on Stack Overflow last week, but it didn’t get any engagement, so I’m hoping that posting here will help me connect with someone who can answer my questions.

I am working with tf.distribute strategy scopes and custom training loops with Keras models. Consider this script, which closely follows this tutorial: Custom training with tf.distribute.Strategy  |  TensorFlow Core

import os

import numpy as np
import tensorflow as tf

print(tf.__version__)

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

# The two flags below are discussed after the script: whether the model is
# built under strategy.scope(), and whether it is a plain Sequential model or
# wrapped with the Keras functional API.
CREATE_MODEL_WITH_SCOPE = False
CREATE_MODEL_SEQUENTIAL = True

input_image = np.random.random((28, 28, 1))
target_image = np.random.random((28, 28, 1))

def data_generator():
    while True:
        yield input_image, target_image


# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

BUFFER_SIZE = 60000  # len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(
    GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")


class MyTask:
    def __init__(self, strategy):
        self.strategy = strategy
        with self.strategy.scope():
            # Set reduction to `none` so we can do the reduction afterwards and divide by
            # global batch size.
            self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
              from_logits=True,
              reduction=tf.keras.losses.Reduction.NONE
            )

        with self.strategy.scope():
            self.test_loss = tf.keras.metrics.Mean(name='test_loss')
            self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                name='train_accuracy')
            self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                name='test_accuracy')

        # model, optimizer, and checkpoint must be created under `strategy.scope`.
        if CREATE_MODEL_WITH_SCOPE:
            with self.strategy.scope():
                self.model = self.create_model()
        else:
            self.model = self.create_model()
        with self.strategy.scope():
            self.optimizer = tf.keras.optimizers.Adam()
            self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)

        self.discriminator = self.create_model()

    def create_model(self):
        input = tf.keras.Input(
            shape=(28, 28, 1),
            name="input"
        )
        layers = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Conv2D(64, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])
        if CREATE_MODEL_SEQUENTIAL:
            return layers
        output = layers(input)
        model = tf.keras.Model(
            inputs=input,
            outputs=output
        )
        return model

    def create_model_sequential(self):
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Conv2D(64, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])
        return model

    def train_step(self, inputs):
        images, labels = inputs

        with tf.GradientTape() as tape:
            predictions = self.model(images, training=True)
            loss = self.compute_loss(labels, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        assert len(gradients) == len(self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        # train_accuracy.update_state(labels, predictions)
        return loss

    def compute_loss(self, labels, predictions):
        per_example_loss = self.loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    def test_step(self, inputs):
        images, labels = inputs

        predictions = self.model(images, training=False)
        t_loss = self.loss_object(labels, predictions)

        self.test_loss.update_state(t_loss)
        self.test_accuracy.update_state(labels, predictions)

    # `run` replicates the provided computation and runs it
    # with the distributed input.
    @tf.function
    def distributed_train_step(self, dataset_inputs):
        per_replica_losses = strategy.run(self.train_step, args=(dataset_inputs,))
        reduced_losses = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        return reduced_losses

    @tf.function
    def distributed_test_step(self, dataset_inputs):
        return strategy.run(self.test_step, args=(dataset_inputs,))

    def fit(self):
        for epoch in range(EPOCHS):
            # TRAIN LOOP
            total_loss = 0
            num_batches = 0
            for x in train_dist_dataset:
                new_loss = self.train_step(x)
                total_loss += new_loss
                num_batches += 1
            train_loss = total_loss / num_batches

            # TEST LOOP
            for x in test_dist_dataset:
                self.distributed_test_step(x)

            if epoch % 2 == 0:
                self.checkpoint.save(checkpoint_prefix)

            template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                        "Test Accuracy: {}")
            print(template.format(epoch + 1, train_loss,
                                  self.train_accuracy.result() * 100, self.test_loss.result(),
                                  self.test_accuracy.result() * 100))

            # template = ("Epoch {}, Loss: {}")
            # print (template.format(epoch+1, train_loss['loss1']))

            self.test_loss.reset_states()
            self.train_accuracy.reset_states()
            self.test_accuracy.reset_states()

task = MyTask(strategy)
task.fit()

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='eval_accuracy')

new_model = task.create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)


@tf.function
def eval_step(images, labels):
    predictions = new_model(images, training=False)
    eval_accuracy(labels, predictions)


checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
    eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))

There are two flags in this script that control the behaviour:

  • CREATE_MODEL_WITH_SCOPE controls whether the model is created under a with strategy.scope(): block. If this flag is False, no explicit scope is used when creating the model.
  • CREATE_MODEL_SEQUENTIAL controls whether the model is a keras.Sequential model. If this flag is False, the model is wrapped with the Keras functional API, but it is otherwise the same model.

Depending on the combination of flags I use and the TensorFlow environment, this script either works or produces an error such as this one:

Traceback (most recent call last):
  File "scratches/scratch_85.py", line 194, in <module>
    task.fit()
  File "scratches/scratch_85.py", line 168, in fit
    new_loss = self.train_step(x)
  File "scratches/scratch_85.py", line 132, in train_step
    self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
  File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 604, in apply_gradients
    self._create_all_weights(var_list)
  File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 783, in _create_all_weights
    self._create_slots(var_list)
  File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/adam.py", line 127, in _create_slots
    self.add_slot(var, 'm')
  File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 838, in add_slot
    raise ValueError(
ValueError: Trying to create optimizer slot variable under the scope for tf.distribute.Strategy (<tensorflow.python.distribute.distribute_lib._DefaultDistributionStrategy object at 0x7f597f1eadc0>), which is different from the scope used for the original variable (MirroredVariable:{
  0: <tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 32) dtype=float32, numpy=
array(...)
}). Make sure the slot variables are created under the same strategy scope. This may happen if you're restoring from a checkpoint outside the scope

The expected output from the script looks like this:

Epoch 1, Loss: 0.5163882374763489, Accuracy: 0.0, Test Loss: 0.4203348755836487, Test Accuracy: 85.15999603271484
Epoch 2, Loss: 0.342644065618515, Accuracy: 0.0, Test Loss: 0.36492371559143066, Test Accuracy: 86.51000213623047
Epoch 3, Loss: 0.2957099378108978, Accuracy: 0.0, Test Loss: 0.30147236585617065, Test Accuracy: 89.18000030517578
Epoch 4, Loss: 0.2637444734573364, Accuracy: 0.0, Test Loss: 0.2926381230354309, Test Accuracy: 89.77000427246094
Epoch 5, Loss: 0.24089021980762482, Accuracy: 0.0, Test Loss: 0.2793895900249481, Test Accuracy: 90.36000061035156
Epoch 6, Loss: 0.221912682056427, Accuracy: 0.0, Test Loss: 0.26250553131103516, Test Accuracy: 90.30999755859375
Epoch 7, Loss: 0.20427824556827545, Accuracy: 0.0, Test Loss: 0.2791960835456848, Test Accuracy: 90.05000305175781
Epoch 8, Loss: 0.18922416865825653, Accuracy: 0.0, Test Loss: 0.24758891761302948, Test Accuracy: 91.23999786376953
Epoch 9, Loss: 0.17512068152427673, Accuracy: 0.0, Test Loss: 0.24345894157886505, Test Accuracy: 91.27999877929688
Epoch 10, Loss: 0.16123878955841064, Accuracy: 0.0, Test Loss: 0.23703424632549286, Test Accuracy: 91.55999755859375
Accuracy after restoring the saved model without strategy: 91.27999877929688

Here’s what I’ve observed in two different environments:

  • tf_2.4 = TensorFlow 2.4.1, Python 3.8, CUDA 10.1.243

  • tf_2.5 = TensorFlow 2.5.0, Python 3.8, CUDA 11.2.2

  • CREATE_MODEL_WITH_SCOPE=False, CREATE_MODEL_SEQUENTIAL=False

    • tf_2.4: Works correctly
    • tf_2.5: ValueError: Trying to create optimizer slot variable under the scope…
  • CREATE_MODEL_WITH_SCOPE=True, CREATE_MODEL_SEQUENTIAL=False

    • tf_2.4: ValueError: Trying to create optimizer slot variable under the scope…
    • tf_2.5: Works correctly
  • CREATE_MODEL_WITH_SCOPE=True, CREATE_MODEL_SEQUENTIAL=True

    • tf_2.4: Works correctly
    • tf_2.5: ValueError: Trying to create optimizer slot variable under the scope…
  • CREATE_MODEL_WITH_SCOPE=False, CREATE_MODEL_SEQUENTIAL=True

    • tf_2.4: Works correctly
    • tf_2.5: ValueError: Trying to create optimizer slot variable under the scope…

From the documentation, I think that CREATE_MODEL_WITH_SCOPE should be True. But I would like to understand what is going on in the various cases. When are the slot variables being created, and how can I make sure they end up under the proper scope? Why is there different behaviour between the Sequential and functional API versions of the same model? And why do these two versions of TensorFlow produce opposite results on these tests?
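
From the traceback, the slot variables seem to be created lazily, on the first apply_gradients call, which would be why the error only surfaces there. A minimal stand-alone check of that (assuming the TF 2.4/2.5-style tf.keras optimizer, where variables() is a method and slots are built inside apply_gradients; this is not part of the script above):

import tensorflow as tf

v = tf.Variable(1.0)
opt = tf.keras.optimizers.Adam()
print([w.name for w in opt.variables()])  # empty: no optimizer variables yet
opt.apply_gradients([(tf.constant(0.1), v)])
print([w.name for w in opt.variables()])  # iteration counter plus the m/v slots for v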

Any help in understanding these questions would be greatly appreciated. This question arose because I want to use a custom loop with a more complicated Keras functional API model, but I run into this slot variable error and I don’t know how to work around it. I can’t figure out how to manage the scopes, and replacing the model with a Sequential one is not an option.

@lgusm Do we have subscribers to the distributed-training tag?

I don’t know, good question.

I’ll try to find someone to look into this

The error indicates that a variable and its corresponding slot variable in the optimizer (such as the momentum) were created under different tf.distribute strategy scopes. In your case, it will error out if your model is not created under strategy.scope() but your optimizer is. To solve it, you can use CREATE_MODEL_WITH_SCOPE to make sure the model is created under the same scope as the optimizer.
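
In sketch form, the arrangement the error message asks for looks like this (reusing a model-building function like the script’s create_model and its checkpoint_dir purely as an illustration, not as a drop-in fix):

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()                  # build the model under the strategy
    optimizer = tf.keras.optimizers.Adam()  # its slot variables do not exist yet
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    # The error message notes that restoring a checkpoint outside the scope can
    # create slots under the wrong strategy, so restore under the scope as well.
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# The slots themselves are created on the first apply_gradients call, so that
# call should also run with this strategy current, e.g. from a train step
# executed via strategy.run.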

Thanks for a swift reply @Yuefeng_Zhou :+1:

@QEDan We’ve reached out to the tf.distribute.Strategy team (cc @billy). Hope this answers (some of) your questions, and don’t hesitate to share your experience or provide more feedback.

@thea Should we get a team badge for @Yuefeng_Zhou?

Thanks everyone

Thank you for the response. As I mentioned in the post, I expect to get the correct behaviour from this script when CREATE_MODEL_WITH_SCOPE=True, but not necessarily when CREATE_MODEL_WITH_SCOPE=False, for the reasons you outlined.

With CREATE_MODEL_WITH_SCOPE=True, both the model and the optimizer are created under the same scope. In that case, under TensorFlow 2.4, I get the correct behaviour if I use a Sequential model, but not with the functional API version of the same model. Under TensorFlow 2.5 it is the other way around: the functional API model works, but the Sequential one does not.

So I would still like an explanation for why variables are getting created under different scopes depending on the Keras API that I use to build the model, and why different versions of TensorFlow behave differently.
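
To make the question more concrete, here is a small stand-alone sketch (my own, not from the tutorial) that prints which kind of variable each combination ends up with. My expectation is that MirroredVariable means the weights were created under the strategy, while a plain ResourceVariable (or similar) means they were not:

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def make_layers():
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])

def build(sequential, with_scope):
    def _build():
        layers = make_layers()
        if sequential:
            return layers                          # variable creation is deferred
        inp = tf.keras.Input(shape=(28, 28, 1))
        return tf.keras.Model(inp, layers(inp))    # variables are created right here
    if with_scope:
        with strategy.scope():
            return _build()
    return _build()

x = np.zeros((1, 28, 28, 1), np.float32)
for sequential in (True, False):
    for with_scope in (True, False):
        model = build(sequential, with_scope)
        # First call builds any still-deferred variables; whether they end up
        # under the scope captured at construction appears to be version-dependent.
        model(x)
        kinds = sorted({type(w).__name__ for w in model.weights})
        print(f"sequential={sequential}, scope={with_scope}: {kinds}")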

There must be some changes in the Keras API. Let me find someone from the Keras team.
