How to redirect gradients in a Keras model?

Hello everyone, I’m trying to define a custom layer in my Keras model. Its purpose is to take a motion prediction vector and apply it to an image using image warping. The warped frame is then used in the loss function to see how well it matches the second frame in the series. In theory, the model should try to find the motion vectors that best describe the transition from frame T to frame T+1.

However, as it stands, the model complains that there are no gradients. This is likely because the operations in the custom layer are not differentiable, which is fine: the only thing I need from those operations is the warped frame to guide the model.

So my question is: how do I prevent the gradients from flowing into this custom layer? I only need to update the CNN portions of the model with respect to the warped frame and the loss.

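import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, Layer, concatenate
from tensorflow.keras.models import Model

class WarpFrameLayer(Layer):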
    def __init__(self, **kwargs):
        super(WarpFrameLayer, self).__init__(**kwargs)

    def call(self, inputs):
        frame, dx_dy = inputs
        warped_frame = self.warp_function(frame, dx_dy)
        return warped_frame

    @tf.function
    def warp_function(self, frame, dx_dy):
        num_images = tf.shape(frame)[0]
        translations = tf.reshape(dx_dy, [num_images, 2])
        zeros = tf.zeros([num_images, 1], dtype=tf.float32)
        ones = tf.ones([num_images, 1], dtype=tf.float32)
        transforms = tf.concat([ones, zeros, translations[:, 0:1], zeros, ones, translations[:, 1:2], zeros, zeros],
                               axis=1)
        output_shape = tf.shape(frame)[1:3]
        warped_frame = tf.raw_ops.ImageProjectiveTransformV3(images=frame, transforms=transforms,
                                                             output_shape=output_shape, interpolation="BILINEAR",
                                                             fill_value=0)
        return warped_frame

def camera_translation_model(input_shape):
    # Frame inputs
    frame_t = Input(shape=input_shape, name='frame_t')
    frame_t_plus_1 = Input(shape=input_shape, name='frame_t_plus_1')

    # Motion prediction model
    conv1_t = Conv2D(16, (3, 3), activation='relu', padding='same')(frame_t)
    conv1_t_plus_1 = Conv2D(16, (3, 3), activation='relu', padding='same')(frame_t_plus_1)
    concat_features = concatenate([conv1_t, conv1_t_plus_1])
    conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(concat_features)
    conv3 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    flattened = Flatten()(conv3)
    motion_pred = Dense(2, activation='linear', name='motion_pred')(flattened)

    warped_frame = WarpFrameLayer()([frame_t, motion_pred])

    # warped frame guides the model's motion estimation between frames
    model = Model(inputs=[frame_t, frame_t_plus_1], outputs=warped_frame)
    return model

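tf.stop_gradient is the standard way to block gradient flow through a tensor, and you could apply it to the output of the warp operation in WarpFrameLayer:

def call(self, inputs):
    frame, dx_dy = inputs
    warped_frame = self.warp_function(frame, dx_dy)
    return tf.stop_gradient(warped_frame)

However, this will not make the model trainable. The warped frame is the model's only output, so the only path from the loss back to the CNN weights runs through the warp via dx_dy. Blocking that path leaves the CNN with no gradients, which is the same error you already see. For the CNN to be updated based on how well the warped frame matches the subsequent frame in the series, the warp itself must be differentiable with respect to dx_dy. As far as I know, the gradient registered for ImageProjectiveTransformV3 is defined only with respect to images, not transforms, which would explain the error. A warp built from differentiable ops, such as bilinear sampling, lets the gradient reach motion_pred.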
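
For the CNN to receive a gradient at all, the warp has to be differentiable with respect to dx_dy. Below is a minimal sketch of one way to do that, assuming pure translation, float32 inputs of shape [batch, height, width, channels], and bilinear interpolation with zero fill. translate_bilinear is a hypothetical helper written for illustration, not a TensorFlow API. It follows the same convention as the transform matrix in the question, where output pixel (x, y) samples the input at (x + dx, y + dy).

```python
import tensorflow as tf


def translate_bilinear(frame, dx_dy):
    """Hypothetical helper: translate images by dx_dy pixels, differentiably.

    frame: float tensor [B, H, W, C]; dx_dy: float tensor [B, 2].
    Samples that fall outside the image contribute zero.
    """
    shape = tf.shape(frame)
    height, width = shape[1], shape[2]
    max_x = tf.cast(width - 1, frame.dtype)
    max_y = tf.cast(height - 1, frame.dtype)

    # Pixel grid of the output image, each of shape [H, W].
    xs = tf.cast(tf.range(width), frame.dtype)
    ys = tf.cast(tf.range(height), frame.dtype)
    grid_x, grid_y = tf.meshgrid(xs, ys)

    # Sampling coordinates in the input image, shape [B, H, W].
    sample_x = grid_x[tf.newaxis] + tf.reshape(dx_dy[:, 0], [-1, 1, 1])
    sample_y = grid_y[tf.newaxis] + tf.reshape(dx_dy[:, 1], [-1, 1, 1])

    # Integer corner and fractional offsets. The gradient with respect to
    # dx_dy flows through the fractional weights wx and wy.
    x0 = tf.floor(sample_x)
    y0 = tf.floor(sample_y)
    wx = (sample_x - x0)[..., tf.newaxis]
    wy = (sample_y - y0)[..., tf.newaxis]

    def gather(ix, iy):
        # Look up pixels at integer-valued coordinates; zero outside the image.
        inside = (ix >= 0) & (ix <= max_x) & (iy >= 0) & (iy <= max_y)
        ix_c = tf.cast(tf.clip_by_value(ix, 0.0, max_x), tf.int32)
        iy_c = tf.cast(tf.clip_by_value(iy, 0.0, max_y), tf.int32)
        idx = tf.stack([iy_c, ix_c], axis=-1)             # [B, H, W, 2]
        pixels = tf.gather_nd(frame, idx, batch_dims=1)   # [B, H, W, C]
        return pixels * tf.cast(inside, frame.dtype)[..., tf.newaxis]

    top = (1.0 - wx) * gather(x0, y0) + wx * gather(x0 + 1.0, y0)
    bottom = (1.0 - wx) * gather(x0, y0 + 1.0) + wx * gather(x0 + 1.0, y0 + 1.0)
    return (1.0 - wy) * top + wy * bottom


# Quick check that a gradient reaches the motion vector.
frame = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
dx_dy = tf.constant([[0.5, 0.0]])
with tf.GradientTape() as tape:
    tape.watch(dx_dy)
    warped = translate_bilinear(frame, dx_dy)
    loss = tf.reduce_sum(warped)
grad = tape.gradient(loss, dx_dy)  # a tensor of shape [1, 2], not None
```

To try it in the model, call translate_bilinear(frame, dx_dy) from WarpFrameLayer.call in place of self.warp_function. Note that the gradient from bilinear sampling only sees the immediate pixel neighbours, so large motions may still be hard to learn without something like a coarse-to-fine scheme.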