Hello there :wave:
Today I ran into a cumbersome error that only happens when running on CPU instead of GPU. I tracked the source of the error down to grouped convolutions and put together a minimal reproducible snippet. I only suspected grouped convolutions because I ran into some problems with them a few days ago when using SavedModels, so finding the cause was pure luck.
It would be good to improve the error message or even get this fixed if possible :pray:
Happy to help if someone can give me some direction!
**System information**
- Have I written custom code: yes, the code snippet
- OS Platform and Distribution: Linux Ubuntu 20.04
- TensorFlow installed from: binary, via pip
- TensorFlow version: 2.5.0
- Python version: 3.8
- CUDA/cuDNN version: CUDA 11.4 (cuDNN 8.2.0)
- GPU model and memory: NVIDIA GeForce RTX 2070 with Max-Q Design
**Describe the current behavior**
As of now, running the snippet below throws an error on CPU but not on GPU.
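To see the CPU behavior on a machine that also has a GPU, one option is to hide the GPU from TensorFlow before it is imported. A minimal sketch, assuming the install honors the standard `CUDA_VISIBLE_DEVICES` environment variable (this is not part of the original report, just one way to force the CPU path):

```python
import os

# Hide every CUDA device from this process so TensorFlow falls back to CPU.
# This has to run before `import tensorflow`, otherwise the GPU may already
# have been initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
```

Running the reproduction snippet in the same process after these lines should then take the CPU code path.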
**Describe the expected behavior**
Simple:
- having a better error message (pointing out the lack of support for grouped convolutions on CPU)
- or even better, if that could get fixed :)
**Standalone code to reproduce the issue**
```python
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
samples = tf.zeros((1, 256, 256, 3), dtype=tf.float32)
model = Sequential([layers.Conv2D(18, padding='same', kernel_size=3, groups=1), layers.GlobalAveragePooling2D(), layers.Dense(1)])
trouble_model = Sequential([layers.Conv2D(18, padding='same', kernel_size=3, groups=3), layers.GlobalAveragePooling2D(), layers.Dense(1)])
# Backprop on classic model
with tf.GradientTape() as tape:
    out = model(samples, training=True)
grads = tape.gradient(out, model.trainable_weights)
# Now with grouped conv
with tf.GradientTape() as tape:
    out = trouble_model(samples, training=True)
grads = tape.gradient(out, trouble_model.trainable_weights)
```
This runs successfully on GPU, but on CPU it throws the following:
```
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-1-e03a8706f9a2> in <module>
19 with tf.GradientTape() as tape:
20 out = trouble_model(samples, training=True)
---> 21 grads = tape.gradient(out, trouble_model.trainable_weights)
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/eager/backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
1072 for x in nest.flatten(output_gradients)]
1073
-> 1074 flat_grad = imperative_grad.imperative_grad(
1075 self._tape,
1076 flat_targets,
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/eager/imperative_grad.py in imperative_grad(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)
69 "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
70
---> 71 return pywrap_tfe.TFE_Py_TapeGradient(
72 tape._tape, # pylint: disable=protected-access
73 target,
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/eager/backprop.py in _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads, skip_input_indices, forward_pass_name_scope)
157 gradient_name_scope += forward_pass_name_scope + "/"
158 with ops.name_scope(gradient_name_scope):
--> 159 return grad_fn(mock_op, *out_grads)
160 else:
161 return grad_fn(mock_op, *out_grads)
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/nn_grad.py in _Conv2DGrad(op, grad)
579 # in Eager mode.
580 return [
--> 581 gen_nn_ops.conv2d_backprop_input(
582 shape_0,
583 op.inputs[1],
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/ops/gen_nn_ops.py in conv2d_backprop_input(input_sizes, filter, out_backprop, strides, padding, use_cudnn_on_gpu, explicit_paddings, data_format, dilations, name)
1245 return _result
1246 except _core._NotOkStatusException as e:
-> 1247 _ops.raise_from_not_ok_status(e, name)
1248 except _core._FallbackException:
1249 pass
~/miniconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
6895 message = e.message + (" name: " + name if name is not None else "")
6896 # pylint: disable=protected-access
-> 6897 six.raise_from(core._status_to_exception(e.code, message), None)
6898 # pylint: enable=protected-access
6899
~/miniconda3/lib/python3.8/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Computed input depth 3 doesn't match filter input depth 1 [Op:Conv2DBackpropInput]
```
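For what it's worth, the two numbers in the final message line up with the kernel shape of a grouped convolution: the kernel's input-depth axis is `in_channels // groups`, so with 3 input channels and `groups=3` it is 1, while the input tensor itself has depth 3. The sketch below only reproduces that arithmetic in plain Python; `grouped_kernel_shape` is a hypothetical helper written for illustration, not a TensorFlow function, and reading the error as a missing group-aware check in the CPU kernel is my assumption.

```python
def grouped_kernel_shape(kernel_size, in_channels, filters, groups):
    """Kernel shape (h, w, in_depth, out_depth) of a grouped 2D convolution.

    Hypothetical helper for illustration only; it mirrors how the kernel's
    input-depth axis shrinks to in_channels // groups.
    """
    if in_channels % groups != 0 or filters % groups != 0:
        raise ValueError("in_channels and filters must both be divisible by groups")
    return (kernel_size, kernel_size, in_channels // groups, filters)


in_channels = 3  # depth of `samples` in the snippet above

classic = grouped_kernel_shape(3, in_channels, 18, groups=1)
grouped = grouped_kernel_shape(3, in_channels, 18, groups=3)

print(classic)  # (3, 3, 3, 18): filter input depth equals the input depth
print(grouped)  # (3, 3, 1, 18): filter input depth is 1, input depth is still 3
```

With `groups=1` the two depths agree, which would explain why the classic model backpropagates fine on both devices.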