Error message for grouped convolution backprop on CPU is uninformative

Hello, I recently learned that gradients and backprop for grouped convolution are not supported on CPU, as discussed in the following GitHub threads:

Nothing in the documentation for grouped convolution indicated that I would run into this issue, and the error message I got when I tried to run the optimizer (copied below) was very unhelpful for tracking down its source.

InvalidArgumentError: Computed input depth 32 doesn't match filter input depth 4 [Op:Conv2DBackpropInput]

I did not expect to see a backprop operation for Conv2D when my model used only Conv1D layers, and I only realized the issue was related to grouping because I was very familiar with the parameters of the model I was running.
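The Conv2D op name is less surprising once you know that, according to the TensorFlow API docs, `tf.nn.conv1d` reshapes its inputs and calls `tf.nn.conv2d` internally, so 1D convolutions run through Conv2D kernels and their backprop ops. A minimal sketch of that equivalence, assuming TensorFlow 2.x; the shapes are arbitrary and the weights are ungrouped for simplicity:

```python
import numpy as np
import tensorflow as tf

# Arbitrary example shapes: batch=4, width=10, in_channels=8, out_channels=16.
x = tf.random.normal((4, 10, 8))
w = tf.random.normal((3, 8, 16))  # (kernel_width, in_channels, out_channels)

# Direct 1D convolution.
y_1d = tf.nn.conv1d(x, w, stride=1, padding="SAME")

# The same computation as a 2D convolution over a height-1 "image":
# input becomes (batch, 1, width, channels), kernel becomes (1, width, in, out).
y_2d = tf.nn.conv2d(tf.expand_dims(x, 1), tf.expand_dims(w, 0),
                    strides=1, padding="SAME")
y_2d = tf.squeeze(y_2d, axis=1)  # drop the dummy height dimension again

print(np.allclose(y_1d.numpy(), y_2d.numpy(), atol=1e-5))
```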

I believe adding the following features would make it much easier for future developers to diagnose the issue:

  • A more informative error message when backprop is attempted on a grouped convolution layer on CPU
  • Possibly a warning whenever a grouped convolution layer is loaded on CPU at all
  • A note in the documentation for convolution layers that gradients for grouped convolution are only supported on GPU

Thank you.


@kqlu4156,

Gradients for grouped convolutions are supported on CPU after TensorFlow 2.9.
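Since CPU support depends on the installed version, a script can check `tf.__version__` before relying on CPU gradients for grouped layers. A minimal sketch: `supports_grouped_conv_grad_on_cpu` is a hypothetical helper, and treating 2.9 itself as supported is one reading of "after 2.9", so pass `minimum=(2, 10)` if that reading turns out to be wrong:

```python
import re


def supports_grouped_conv_grad_on_cpu(version: str, minimum=(2, 9)) -> bool:
    """Hypothetical helper: compare the major.minor part of a version string to a threshold."""
    # Take only the leading "major.minor" digits, so "2.9.0rc1" -> (2, 9).
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= minimum


print(supports_grouped_conv_grad_on_cpu("2.12.0"))  # True
print(supports_grouped_conv_grad_on_cpu("2.8.4"))   # False
```

In practice you would call it as `supports_grouped_conv_grad_on_cpu(tf.__version__)`. Comparing integer tuples matters here: a plain string comparison would rank "2.12" below "2.9".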

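```python
import tensorflow as tf
print(tf.__version__)

# Grouped Conv1D: 128 input channels split into 4 groups of 32 channels each.
inputs = tf.random.normal(shape=(4, 10, 128))
conv1d = tf.keras.layers.Conv1D(32, 3, groups=4, padding='same')

# Record the forward pass, then compute gradients outside the tape context.
with tf.GradientTape() as tape:
    outputs = conv1d(inputs)
grads_1d = tape.gradient(outputs, conv1d.trainable_variables)

print(grads_1d)
```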
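The example above differentiates only with respect to the layer's variables (kernel and bias). The op named in the original error, `Conv2DBackpropInput`, computes the gradient with respect to the layer's input, which an optimizer needs whenever the grouped layer sits after other trainable layers. To be sure that path is exercised too, one can also ask for the input gradient. A minimal sketch, assuming a TensorFlow version with CPU support as stated above (on older versions this is where the error would be expected):

```python
import tensorflow as tf

inputs = tf.random.normal(shape=(4, 10, 128))
conv1d = tf.keras.layers.Conv1D(32, 3, groups=4, padding="same")

with tf.GradientTape() as tape:
    # `inputs` is a plain tensor, not a variable, so it must be watched explicitly.
    tape.watch(inputs)
    outputs = conv1d(inputs)

# One call returns the input gradient plus the kernel and bias gradients.
grad_inputs, grad_kernel, grad_bias = tape.gradient(
    outputs, [inputs] + list(conv1d.trainable_variables))

print(grad_inputs.shape, grad_kernel.shape, grad_bias.shape)
```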

Output:

2.12.0
[<tf.Tensor: shape=(3, 32, 32), dtype=float32, numpy=
array([[[ 5.6082282e+00,  5.6082282e+00,  5.6082282e+00, ...,
         -5.7246876e+00, -5.7246876e+00, -5.7246876e+00],
        [-5.2481537e+00, -5.2481537e+00, -5.2481537e+00, ...,
         -1.5736814e+00, -1.5736814e+00, -1.5736814e+00],
        [ 9.2298737e+00,  9.2298737e+00,  9.2298737e+00, ...,
         -7.0252337e+00, -7.0252337e+00, -7.0252337e+00],
        ...,
        [-8.7338948e-01, -8.7338948e-01, -8.7338948e-01, ...,
          7.0602713e+00,  7.0602713e+00,  7.0602713e+00],
        [ 1.1597365e+00,  1.1597365e+00,  1.1597365e+00, ...,
          1.1881893e+00,  1.1881893e+00,  1.1881893e+00],
        [-4.6490440e+00, -4.6490440e+00, -4.6490440e+00, ...,
         -1.7533340e+01, -1.7533340e+01, -1.7533340e+01]],

       [[ 8.1101370e-01,  8.1101370e-01,  8.1101370e-01, ...,
         -3.8097646e+00, -3.8097646e+00, -3.8097646e+00],
        [-4.1635637e+00, -4.1635637e+00, -4.1635637e+00, ...,
         -9.5894051e-01, -9.5894051e-01, -9.5894051e-01],
        [ 1.2245650e+01,  1.2245650e+01,  1.2245650e+01, ...,
         -6.6381464e+00, -6.6381464e+00, -6.6381464e+00],
        ...,
        [ 1.6338730e+00,  1.6338730e+00,  1.6338730e+00, ...,
          7.8838940e+00,  7.8838940e+00,  7.8838940e+00],
        [ 1.5161037e-02,  1.5161037e-02,  1.5161037e-02, ...,
          1.5670588e+00,  1.5670588e+00,  1.5670588e+00],
        [-3.5443621e+00, -3.5443621e+00, -3.5443621e+00, ...,
         -1.6055473e+01, -1.6055473e+01, -1.6055473e+01]],

       [[-1.3902726e+00, -1.3902726e+00, -1.3902726e+00, ...,
         -4.0715046e+00, -4.0715046e+00, -4.0715046e+00],
        [-6.2915697e+00, -6.2915697e+00, -6.2915697e+00, ...,
         -1.0634291e+00, -1.0634291e+00, -1.0634291e+00],
        [ 1.0099386e+01,  1.0099386e+01,  1.0099386e+01, ...,
         -5.7550478e+00, -5.7550478e+00, -5.7550478e+00],
        ...,
        [ 2.5703638e+00,  2.5703638e+00,  2.5703638e+00, ...,
          7.6247263e+00,  7.6247263e+00,  7.6247263e+00],
        [ 2.0062921e+00,  2.0062921e+00,  2.0062921e+00, ...,
          6.0164762e-01,  6.0164762e-01,  6.0164762e-01],
        [-3.8103375e+00, -3.8103375e+00, -3.8103375e+00, ...,
         -1.3640256e+01, -1.3640256e+01, -1.3640256e+01]]], dtype=float32)>, <tf.Tensor: shape=(32,), dtype=float32, numpy=
array([40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40.,
       40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40., 40.,
       40., 40., 40., 40., 40., 40.], dtype=float32)>]
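Both shapes in this output can be checked by hand. Keras stores a grouped convolution kernel with a per-group input depth of `in_channels // groups`, so each of the 4 groups here sees 128 // 4 = 32 input channels, giving (3, 32, 32). The "filter input depth" in the original error message is this same per-group value; a filter depth of 4 against an input depth of 32 would be consistent with `groups=8`, though the original post does not state the layer configuration. The helper below is hypothetical and only restates that shape rule:

```python
def grouped_conv1d_kernel_shape(kernel_size, in_channels, filters, groups):
    """Hypothetical helper: kernel shape of a grouped Keras Conv1D layer."""
    if in_channels % groups or filters % groups:
        raise ValueError("in_channels and filters must both be divisible by groups")
    return (kernel_size, in_channels // groups, filters)


# The layer from the example above: 128 input channels, 32 filters, 4 groups.
print(grouped_conv1d_kernel_shape(3, 128, 32, 4))  # (3, 32, 32)

# Bias gradient: differentiating the (implicitly summed) outputs adds 1 per
# output position, so each entry equals batch * width for (batch, width, channels).
batch, width = 4, 10
print(batch * width)  # 40
```

That last line is why every entry of the bias gradient above is 40.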

Please find the gist for reference.

Thank you!