How to calculate the FLOPs of a Transformer in TensorFlow?

I know that

    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())

can calculate the FLOPs.
But where can I get the graph of a Transformer model?
Please help me.

There is quite a long thread about this for TF 2.x:
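
For TF 2.x, one commonly suggested way to get a graph is to trace the Keras model into a concrete function, freeze its variables, and pass the resulting graph to the v1 profiler. The sketch below assumes the Transformer is a `tf.keras.Model` that takes a single float32 input; `count_flops` and `input_shape` are hypothetical names used only for illustration, and `convert_variables_to_constants_v2` lives in a non-public module path that may move between TensorFlow versions. The profiler only counts ops that have a registered FLOPs statistic, so the total may undercount some layers.

```python
import tensorflow as tf
# Non-public import path; assumed to be available in the installed TF version.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def count_flops(model, input_shape, batch_size=1):
    """Return the profiler's float-op count for one forward pass.

    `model` is assumed to be a tf.keras.Model with one float32 input;
    `input_shape` excludes the batch dimension.
    """
    spec = tf.TensorSpec([batch_size, *input_shape], tf.float32)

    # Trace the model into a graph for a fixed input signature.
    concrete = tf.function(lambda x: model(x)).get_concrete_function(spec)

    # Freeze variables into constants so the graph is self-contained.
    frozen = convert_variables_to_constants_v2(concrete)

    # Same profiler call as in the question, via the compat.v1 namespace.
    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    info = tf.compat.v1.profiler.profile(
        graph=frozen.graph,
        run_meta=tf.compat.v1.RunMetadata(),
        cmd="op",
        options=opts,
    )
    return info.total_float_ops


if __name__ == "__main__":
    # Tiny stand-in model; replace it with your own Transformer.
    demo = tf.keras.Sequential(
        [tf.keras.Input(shape=(16,)), tf.keras.layers.Dense(8)]
    )
    print("FLOPs:", count_flops(demo, (16,)))
```

The count depends on the batch size and sequence length baked into `input_shape`, so use the shapes you actually care about.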
