Keras with DTensor - gradient errors

Hello all,

I’m taking an existing Keras model and trying to get it to run with data parallelism. I’m basing it on what I see here.

My model definition looks like this:

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, regularizers

    with tf.keras.dtensor.experimental.layout_map_scope(self.layout_map):
        self.inputs = keras.Input(shape=self.input_size, sparse=True)
        self.hiddens = layers.Dense(
            100,
            activation="relu",
            kernel_regularizer=regularizers.L1L2(l1=L1_REG),
            bias_regularizer=regularizers.L1L2(l1=L1_REG),
            input_shape=(self.input_size,),
            name="feature",
        )
        self.hiddens_layer = self.hiddens(self.inputs)
        self.dropout = layers.Dropout(DROPOUT_RATE)(self.hiddens_layer)
        # Feed the dropout output to the final layer; otherwise the Dropout layer is never used.
        self.outputs = layers.Dense(
            1,
            activation="sigmoid",
            kernel_regularizer=regularizers.L1L2(l1=L1_REG),
            bias_regularizer=regularizers.L1L2(l1=L1_REG),
            name="feature_2",
        )(self.dropout)

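To make the variable shapes concrete, here is a minimal NumPy sketch of the same architecture (Dense(100, relu), then Dropout, then Dense(1, sigmoid)). It assumes a tiny stand-in input size of 20 instead of ~5M and a placeholder dropout rate of 0.5, since DROPOUT_RATE isn’t shown; the regularizers, sparse input, and DTensor layouts are left out, and the parameter names are hypothetical labels modeled on the layer names. It only illustrates which weights exist and what shapes they have.

```python
import numpy as np

def forward(x, params, dropout_rate=0.5, training=False, rng=None):
    """Dense(100, relu) -> Dropout -> Dense(1, sigmoid), on dense NumPy arrays."""
    h = np.maximum(0.0, x @ params["feature/kernel"] + params["feature/bias"])
    if training and dropout_rate > 0.0:
        if rng is None:
            rng = np.random.default_rng(0)
        keep = rng.random(h.shape) >= dropout_rate
        # Inverted dropout: scale kept units so the expected activation is unchanged.
        h = h * keep / (1.0 - dropout_rate)
    z = h @ params["feature_2/kernel"] + params["feature_2/bias"]
    return 1.0 / (1.0 + np.exp(-z))  # sigmoid: probabilities, not logits

input_size, hidden, batch = 20, 100, 4  # tiny stand-in for the real ~5M input size
rng = np.random.default_rng(42)
# Hypothetical parameter names modeled on the layer names "feature" and "feature_2".
params = {
    "feature/kernel": rng.normal(scale=0.1, size=(input_size, hidden)),
    "feature/bias": np.zeros(hidden),
    "feature_2/kernel": rng.normal(scale=0.1, size=(hidden, 1)),
    "feature_2/bias": np.zeros(1),
}

x = rng.random((batch, input_size))
y_pred = forward(x, params, training=True, rng=rng)

for name, value in params.items():
    print(name, value.shape)
print("output", y_pred.shape)
```

With the real input size, the first kernel would be [5138741, 100], which is the gradient shape reported in the error.
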
I’m writing a custom training loop, following the docs linked above, and my training step looks like this:

    @tf.function
    def train_step(self, x, y, w, optimizer, metrics):
        with tf.GradientTape() as tape:
            # The output layer uses a sigmoid, so the model returns probabilities, not logits.
            y_pred = self.model(x, training=True)
            loss = tf.reduce_sum(tf.math.multiply(
                tf.keras.losses.binary_crossentropy(
                    y, y_pred, from_logits=False), w))

        gradients = tape.gradient(loss, self.model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        for metric in metrics.values():
            metric.update_state(y_true=y, y_pred=y_pred)

        # x is a sparse tensor, so take the batch size from the (assumed dense) labels instead of len(x).
        batch_size = tf.cast(tf.shape(y)[0], loss.dtype)
        loss_per_sample = loss / batch_size
        results = {'loss': loss_per_sample}
        return results

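Because the output layer uses a sigmoid activation, the model returns probabilities, so the loss should be computed on probabilities (the equivalent of from_logits=False); passing probabilities with from_logits=True would push them through a second sigmoid. As a sketch only, here is the weighted loss from the training step re-expressed in NumPy, with made-up labels, predictions, and weights. The clipping constant is an assumption to keep log() finite, not a claim about Keras internals.

```python
import numpy as np

EPSILON = 1e-7  # assumed small constant that keeps log() finite

def binary_crossentropy(y_true, y_prob):
    """Per-sample binary cross-entropy on probabilities, averaged over the last axis."""
    p = np.clip(y_prob, EPSILON, 1.0 - EPSILON)
    bce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    return bce.mean(axis=-1)

def weighted_loss(y_true, y_prob, w):
    """Weighted sum of per-sample losses, plus that sum divided by the batch size."""
    per_sample = binary_crossentropy(y_true, y_prob)  # shape (batch,)
    total = np.sum(per_sample * w)                     # corresponds to `loss` in train_step
    return total, total / y_true.shape[0]             # corresponds to `loss_per_sample`

# Made-up batch of two examples: labels, sigmoid outputs, and sample weights.
y_true = np.array([[1.0], [0.0]])
y_prob = np.array([[0.9], [0.2]])
w = np.array([1.0, 2.0])

total, per_sample_loss = weighted_loss(y_true, y_prob, w)
print(total, per_sample_loss)
```
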
I’m getting the following error:

    File "training/model.py", line 266, in train_step *
        optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

    ....

    File "/usr/local/lib/python3.8/dist-packages/keras/optimizers/optimizer_experimental/adam.py", line 175, in update_step
        m.assign_add((gradient - m) * (1 - self.beta_1))

    ValueError: Dimensions must be equal, but are 100 and 0 for '{{node sub_2}} = Sub[T=DT_FLOAT](gradient, sub_2/ReadVariableOp)' with input shapes: [5138741,100], [0].

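The first input shape, [5138741,100], matches the kernel of my first layer (input size ~5M, 100 hidden units), and the trainable variables (which are of type AutoCastVariable) have the same shapes as their gradients. The second shape, [0], belongs to the sub_2/ReadVariableOp operand, i.e. the value read from Adam’s m slot variable in the line shown in the traceback, so the optimizer state seems to have been created empty. I can’t figure out what’s missing. Does anyone have an idea what might be going on? TIA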
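For reference, the failing line in the traceback is Adam’s first-moment update, m.assign_add((gradient - m) * (1 - self.beta_1)). Adam keeps its slot variables (m and v) with the same shape as the variable they track, so for the first kernel m should be [5138741, 100]. This minimal NumPy sketch of that one update, assuming a small (5, 100) stand-in gradient and the usual default beta_1 = 0.9, shows that it works when m matches the gradient’s shape and fails with a shape-mismatch ValueError when m has shape [0], as in the error above. It only reproduces the symptom; it does not explain why the slot variable ended up empty.

```python
import numpy as np

BETA_1 = 0.9  # Adam's usual default first-moment decay (assumed here)

def adam_first_moment_update(m, gradient, beta_1=BETA_1):
    """In-place NumPy version of m.assign_add((gradient - m) * (1 - beta_1))."""
    m += (gradient - m) * (1.0 - beta_1)
    return m

gradient = np.ones((5, 100), dtype=np.float32)  # stand-in for the [5138741, 100] kernel gradient

# Slot variable with the same shape as the variable it tracks: the update works.
m_ok = np.zeros_like(gradient)
adam_first_moment_update(m_ok, gradient)

# Slot variable with shape [0], as reported in the error: the update fails.
m_bad = np.zeros((0,), dtype=np.float32)
try:
    adam_first_moment_update(m_bad, gradient)
except ValueError as err:
    print("ValueError:", err)
```
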