Custom Gradient for Sparse Weight Tensors

I am trying to create a custom layer which computes W*x + b where W is a sparse tensor. It is important that I don’t ever form the dense version of W because it would be too large to store in memory. It is my understanding that the computation of W*x, using tf.sparse.sparse_dense_matmul(W, x), does not have a supported gradient.

To make this work, I am trying to implement a custom gradient using the following code:

@tf.custom_gradient
      def sparse_weight_multiply(self, w):
	     # compute the product sparse_W * inputs, where sparse_W is a sparse tensor formed from the entries in w

          self.sparse_W = tf.sparse.SparseTensor(self.indices, w, self.shape)
          w_inputs = tf.sparse.sparse_dense_matmul(self.sparse_W, self.inputs)

          # define gradient for this function
          def sparse_weight_grad(upstream_grad):
              '''
              upstream_grad is the gradient computed thus far in the computational graph. 
              The output of this function will be the gradient of the function sparse_weight_multiply 
              times upstream_grad, due to the product rule in differentiation. 

              '''
              # check the shape of upstream_grad
              print("Shape of upstream grad: {}".format(upstream_grad.shape))
              print("Shape of inputs: {}".format(inputs.shape))
              print("Shape of weight: {}".format(self.shape))

              # map entries of input to corresponding locations in the gradient of weights*inputs
              n_out   = upstream_grad.shape[0]
              num_RHS = upstream_grad.shape[1]

              indices_i = range(0,self.num_connections)
              indices_j = self.indices[:, 0]
              indices_k = self.indices[:, 1]
              J_indices = tf.cast(tf.transpose(tf.concat([[indices_i],[indices_j]], 0)), tf.int64)

              grad_weights = []

              for l in range(num_RHS):
                  input_permuted = np.array(self.inputs)[:,l][indices_k]
                  sparse_J = tf.sparse.SparseTensor(J_indices, input_permuted, (self.num_connections, n_out))
                  grad_weights.append(tf.sparse.sparse_dense_matmul(sparse_J, tf.reshape(upstream_grad[:,l], (n_out, 1)) ))

              grad_weights = tf.transpose(tf.squeeze(tf.convert_to_tensor(grad_weights)))
              return grad_weights

          return w_inputs, sparse_weight_grad

However, I get the error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: var and grad do not have the same shape[9632] [9632,638] [Op:ResourceApplyAdam]
I believe this is because my input, w, is a tensor of shape [9632]. However, I want to compute the gradient of W*x for each input x to the layer, of which I have 638 in my training set. Thus, the gradient I return has shape [9632,638], corresponding to a gradient with shape [9632] for each input. This matches the shape of the upstream gradient I am given, which has shape (3836, 638). I definitely want to pass a gradient for each input, but how do I tell tensorflow that that is what I am doing?

maybe @markdaoust might be able to shine some light here

1 Like

Okay, you’ve got a bunch of things going on here. Let me break it down into digestible parts.

That’s incorrect. The gradient is defined:

values = tf.Variable([1.,2.,3.])
x = tf.Variable([[1.,2,3],[4,5,6],[7,8,9]])


with tf.GradientTape() as tape:
  sparse = tf.sparse.SparseTensor(indices = [[0,0],[1,1],[2,2]], values=values, dense_shape=[3,3])
  result = tf.sparse.sparse_dense_matmul(sparse, x)
  loss = tf.reduce_sum(result)

tape.gradient(loss, [values,x])
[<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 6., 15., 24.], dtype=float32)>,
 <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
 array([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]], dtype=float32)>]

The tricky part is that the gradient is sparse too: it only returns a gradient for non-zero entries of the sparse matrix.

I haven’t reviewed this code.

tape.gradient takes the gradient of a scalar. If you pass a non-scalar the result is the gradient of the sum. The result of a gradient is always the same shape as the target.

You’re looking for tape.batch_jacobian

Thank you for your response. I see now that the gradient is defined for sparse_dense_matmul. I identified my problem to instead be from trying to optimize the weights of my layer using these gradients. Here is some sample code that reproduces my bug:

import tensorflow as tf
import scipy.sparse as sparse 
import keras
import numpy as np

# define sparse layer
class Sparse(keras.layers.Layer):
    def __init__(self, 
                 C,                 # The connectivity matrix that defines the connectivity/sparsity of the layer
                 title=''):         # Label for the layer

        super(Sparse, self).__init__()
        self.title = title;

        # determine shape and non-zero indices of C
        self.shape = np.shape(C)
        (I, J, _) = sparse.find(C)
        self.num_connections = I.shape[0]; #num non-zero components of C

        self.indices = tf.cast(tf.transpose(tf.concat([[I],[J]], 0)), tf.int64)

    def build(self, input_shape):
        self.w = self.add_weight(shape = (self.num_connections,), initializer="random_normal", trainable=True, name = 'W'+self.title, dtype = tf.float32)
        self.sparse_w = tf.sparse.SparseTensor(self.indices, self.w, self.shape)

        self.b = self.add_weight(shape=(self.shape[0],), initializer="zeros", trainable=True, name = 'b' + self.title)

        super().build(input_shape)

    def call(self, inputs):

        return tf.math.add(tf.sparse.sparse_dense_matmul(self.sparse_w, inputs), tf.expand_dims(self.b, 1))

class Sparse_NN(tf.keras.Model):
    def __init__(self, C):
        super(Sparse_NN, self).__init__() 

        self.sparse_1 = Sparse(C)

    def call(self, inputs):
        x = self.sparse_1(inputs)
        x = tf.nn.relu(x)
        return x

def grad(model, inputs):
    with tf.GradientTape() as tape:
	    loss = tf.reduce_sum(model(inputs))

    # print which variables are being watched by the gradient tape
    print("Variables being differentiated by gradient tape:")
    for var in tape.watched_variables():
        print(var.name)

    return tape.gradient(loss, [model.sparse_1.w, model.sparse_1.b])

x = tf.Variable([[1.],[2.],[1.]])

# define connectivity of layer
C = sparse.csr_matrix(np.eye(3))

# Define network model
print("Defining model")
model = Sparse_NN(C)

# print description of model architecture
print("Building model")
model.build(input_shape=(3,3))
model.summary() 

# Define optimizer to be used
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# apply gradients
grads = grad(model, x)
optimizer.apply_gradients(zip(grads, [model.sparse_1.w, model.sparse_1.b]))

I get the error message

WARNING:tensorflow:Gradients do not exist for variables ['sparse/W:0'] when minimizing the loss.

which tells me that the components of my weight tensor are not being trained. Is there a way around this?