Create the correct variable dtype on custom layer when using mixed precision

I have a model with the following layer:

import tensorflow as tf
from tensorflow.keras import layers

class LayerScale(layers.Layer):
    """Taken from:
    https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/convnext.py

    Layer scale module.
    References:
      - https://arxiv.org/abs/2103.17239
    Args:
      init_values (float): Initial value for layer scale. Should be within
        [0, 1].
      projection_dim (int): Projection dimensionality.
    Returns:
      Tensor multiplied to the scale.
    """

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.init_values = init_values
        self.projection_dim = projection_dim

    def build(self, input_shape):
        self.gamma = tf.Variable(self.init_values * tf.ones((self.projection_dim,)))

    def call(self, x):
        return x * self.gamma

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "init_values": self.init_values,
                "projection_dim": self.projection_dim,
            }
        )
        return config

The model is created without issues, and I can train it in full precision. However, when I attempt mixed precision training, I get the following error:

File "/home/sangohe/projects/ais-segmentation/models/convnext.py", line 244, in call  *
        return x * self.gamma

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float16 of argument 'x'.

The error points to the gamma variable. Is there any way to create this variable with the correct type when I use mixed_precision.set_global_policy("mixed_float16")?

Mixed precision is the use of both 16-bit and 32-bit floating-point types in a model during training to make it run faster and use less memory.

@sangohe
You could use the layer's self.compute_dtype property to create self.gamma with the correct dtype for the active mixed precision policy.
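To see which dtypes the policy assigns to a layer, here is a quick check. It is a sketch assuming TF 2.4 or later, where tf.keras.mixed_precision.set_global_policy is available. Under mixed_float16, a layer computes in float16 but keeps its variables in float32, so compute_dtype is the dtype that matches the (autocast) float16 inputs.

```python
import tensorflow as tf

# The policy must be set before the layer is created for the layer to pick it up.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = tf.keras.layers.Layer()

compute_dtype = layer.compute_dtype    # dtype used for the layer's computations
variable_dtype = layer.variable_dtype  # dtype used to store the layer's weights

# Restore the default policy so later code is unaffected.
tf.keras.mixed_precision.set_global_policy("float32")
```

Here compute_dtype is "float16" and variable_dtype is "float32", which is the split that the error message exposes.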

Your layer works in mixed precision with the following build function:

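def build(self, input_shape):
    # Create gamma directly in the layer's compute dtype
    # (float16 under the mixed_float16 policy) so it matches the input x.
    self.gamma = tf.Variable(
        self.init_values * tf.ones((self.projection_dim,), dtype=self.compute_dtype),
        dtype=self.compute_dtype,
    )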
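One design note: the build above stores the trainable weight itself in float16, whereas the Keras mixed precision guide recommends keeping variables in float32 for numeric stability. An alternative sketch, assuming TF 2.x Keras, creates gamma with add_weight instead of tf.Variable. add_weight uses the layer's variable dtype (float32 under mixed_float16), and Keras casts the weight to the compute dtype (float16) when it is read inside call(). This is not the code from the thread, only one possible variant.

```python
import tensorflow as tf


class LayerScale(tf.keras.layers.Layer):
    """LayerScale variant whose weight is created through add_weight."""

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.init_values = init_values
        self.projection_dim = projection_dim

    def build(self, input_shape):
        # add_weight stores gamma in the variable dtype (float32 under
        # mixed_float16); Keras casts it to the compute dtype inside call().
        self.gamma = self.add_weight(
            name="gamma",
            shape=(self.projection_dim,),
            initializer=tf.keras.initializers.Constant(self.init_values),
            trainable=True,
        )

    def call(self, x):
        return x * self.gamma

    def get_config(self):
        config = super().get_config()
        config.update(
            {"init_values": self.init_values, "projection_dim": self.projection_dim}
        )
        return config


tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = LayerScale(init_values=0.1, projection_dim=4)
out = layer(tf.ones((2, 4)))  # the float32 input is cast to float16 by the layer
tf.keras.mixed_precision.set_global_policy("float32")  # restore the default
```

With this version the output is float16 while layer.gamma is still stored as float32.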

For future reference, the traceback points at return x * self.gamma inside call, and the message reads:
TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float16 of argument 'x'.
TensorFlow rewrites the asterisk into tf.math.multiply(x, y), which is where the Mul op and the names x and y come from. Here x is the layer input, which the mixed_float16 policy casts to float16, and y is self.gamma, which is float32: the default dtype of tf.ones is float32, and no dtype was specified when the variable was created. (Multiplying the Python float self.init_values by tf.ones in build is not itself a problem; that product is simply float32.)
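The traceback points at x * self.gamma in call. A minimal reproduction of that mismatch outside any model, assuming TF 2.x in eager mode, multiplies a float32 variable built from tf.ones (no dtype given) by a float16 tensor that stands in for the autocast layer input. The exact exception type can differ between eager execution and a traced call like the one in the traceback, so the sketch catches TypeError, ValueError, and InvalidArgumentError. The names gamma, x, and product are illustrative only.

```python
import tensorflow as tf

# tf.ones defaults to float32 when no dtype is given, so gamma is float32.
gamma = tf.Variable(0.1 * tf.ones((4,)))

# Stand-in for the layer input after mixed_float16 autocasting.
x = tf.ones((2, 4), dtype=tf.float16)

try:
    x * gamma  # float16 * float32: TensorFlow does not promote dtypes here
    mismatch_raised = False
except (TypeError, ValueError, tf.errors.InvalidArgumentError):
    mismatch_raised = True

# Making both operands the same dtype resolves the mismatch.
product = x * tf.cast(gamma, x.dtype)
```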

Thanks, everyone, for the answers, and sorry for the late reply. All the comments were helpful. After rewriting the build function as @sebastian-sz suggested, the problem was fixed.
