# Strassen's Matrix Multiplication on TensorFlow

I am trying to implement Strassen's matrix multiplication in TensorFlow, using graph execution with the following code:

```python
import tensorflow as tf

@tf.function  # The decorator converts `split_tf` into a `Function`.
def split_tf(matrix):
    shape = tf.shape(matrix)
    shape2 = tf.cast(tf.divide(shape, 2), dtype=tf.int32)
    return (tf.slice(matrix, [0, 0], shape2),
            tf.slice(matrix, [0, shape2[1]], shape2),
            tf.slice(matrix, [shape2[0], 0], shape2),
            tf.slice(matrix, [shape2[0], shape2[1]], shape2))

@tf.function  # The decorator converts `strassen_tf` into a `Function`.
def strassen_tf(x, y):
    # Base case when size of matrices is 1x1
    if tf.rank(x) == 1:
        return tf.math.multiply(x, y)

    # Splitting the matrices into quadrants. This will be done recursively
    # until the base case is reached.
    a, b, c, d = split_tf(x)
    e, f, g, h = split_tf(y)

    # Computing the 7 products, recursively (p1, p2...p7)
    p1 = strassen_tf(a, tf.math.subtract(f, h))
    p2 = strassen_tf(tf.math.add(a, b), h)
    p3 = strassen_tf(tf.math.add(c, d), e)
    p4 = strassen_tf(d, tf.math.subtract(g, e))
    p5 = strassen_tf(tf.math.add(a, d), tf.math.add(e, h))
    p6 = strassen_tf(tf.math.subtract(b, d), tf.math.add(g, h))
    p7 = strassen_tf(tf.math.subtract(a, c), tf.math.add(e, f))

    # Computing the values of the 4 quadrants of the final matrix c
    c11 = tf.math.add(tf.math.subtract(tf.math.add(p5, p4), p2), p6)
    c12 = tf.math.add(p1, p2)
    c21 = tf.math.add(p3, p4)
    c22 = tf.math.subtract(tf.math.subtract(tf.math.add(p1, p5), p3), p7)

    # Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
    c = tf.concat([tf.concat([c11, c12], axis=1), tf.concat([c21, c22], axis=1)], axis=0)

    return c
```

Now when I use the above code to multiply two tensors as below:

```python
n = 2
A = tf.random.normal([n, n], 0, 1, dtype=tf.float64)
B = tf.random.normal([n, n], 0, 1, dtype=tf.float64)
D = strassen_tf(A, B)
```

the code gets stuck at the last statement.

First, I would 100% advise against using recursion in a `tf.function`. I don’t think `tf.function` will play well with recursion. `tf.function` needs to know the entire graph at compile-time, and so it will (I think) trace through the recursion until it has a complete graph. With large matrices this will result in REALLY big graphs that are difficult to compile. You should opt for `tf` control-flow instead, e.g. `tf.scan`, `tf.while_loop`, etc.
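To illustrate what graph-native control flow looks like, here is a minimal sketch (not a Strassen implementation; `matpow` and its arguments are hypothetical names chosen for the example) that computes a matrix power with `tf.while_loop`, so the loop lands in the traced graph as a single op instead of being unrolled by Python recursion:

```python
import tensorflow as tf

@tf.function
def matpow(m, n):
    # Compute m**n with tf.while_loop: the loop becomes a single
    # graph op instead of being unrolled during tracing.
    i = tf.constant(0)
    acc = tf.eye(tf.shape(m)[0], dtype=m.dtype)  # identity accumulator
    _, result = tf.while_loop(
        cond=lambda i, acc: i < n,
        body=lambda i, acc: (i + 1, tf.matmul(acc, m)),
        loop_vars=[i, acc])
    return result
```

Because `tf.while_loop` is itself a TensorFlow op, the graph size stays constant regardless of how many iterations run at execution time.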

Second, if you add some debug prints and inspect the rank, you’ll see the rank always remains `2`. Notice your comment for your base case says the base case is a `1 x 1` matrix, which `tf.rank` will always evaluate to `2`. I think instead what you’re looking for is `tf.size(x) == 1` OR `tf.linalg.matrix_rank(x) == 1`. `tf.rank` returns the number of dimensions in the Tensor, whereas `tf.linalg.matrix_rank` returns the rank in the sense you are trying to use it. Changing that and commenting out the `tf.function` decorators evaluates the function correctly. If you want to use `tf.function`, try to rework the recursion to use native TensorFlow control-flow.
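The distinction is easy to check directly (a quick eager-mode sanity check):

```python
import tensorflow as tf

# A 1x1 matrix still has two dimensions, so tf.rank (number of
# dimensions) is 2, while tf.size (number of elements) is 1.
m = tf.constant([[3.0]])
rank = tf.rank(m)  # -> 2
size = tf.size(m)  # -> 1
```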


Thanks a lot, @Sean_Moriarity, for the detailed reply. My mistake about the rank. However, I have two questions.

1. As I understand it, `tf.function` uses graphs (operations + tensors) for a function's computations. If I don't use `tf.function`, how is the graph generated? And if it is not generated, how does the code run faster, given that graphs internally optimize by pipelining parallel operations together? (This is my understanding; I may be wrong.) Do `tf.scan` and `tf.while_loop` also create graphs?

2. The code runs fine after removing the decorators, i.e. eagerly. How can I run the code non-eagerly, i.e. using graphs?

Yes, recursion is not supported. We already have a ticket on this specific aspect of your issue:


You are correct that `tf.function` builds a graph; `tf.scan` and `tf.while_loop` are native TensorFlow operations that will be included in the graph, so you can use them within `tf.function`. Here’s an internal discussion on TF control flow: Inside TensorFlow: Control Flow - YouTube
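For instance, here is a minimal sketch (`cumsum_scan` is a hypothetical name for the example) of `tf.scan` used inside a `tf.function`; the scan traces into the graph as a single op rather than a Python-level loop:

```python
import tensorflow as tf

@tf.function
def cumsum_scan(xs):
    # tf.scan threads an accumulator over the first axis of xs.
    # With no initializer, the first element seeds the accumulator,
    # so this computes a cumulative sum.
    return tf.scan(lambda acc, x: acc + x, xs)
```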


Okay, got it. That means implementations of recursive linear-algebra algorithms like Strassen's, and other matrix inversion techniques, will have to wait until recursion support is released.

Thanks @Bhack for the pointer.