Does this do what you’re looking for?
import tensorflow as tf
fun = lambda x: tf.nn.softmax(x)
tf_fun = tf.function(fun)
graph = tf_fun.get_concrete_function(tf.constant([1.0])).graph
isinstance(graph, tf.Graph)  # True
graph.get_operations()  # returns a list of tf.Operation objects in the traced graph
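To see what that list actually contains, here is a small variant of the snippet above. It is a sketch only: it assumes TF 2.x, and `softmax_fn` is just an illustrative name. It traces with a `tf.TensorSpec` instead of a concrete tensor, so the input length stays unspecified (`None`), and then prints each operation's name, type, and output shapes.

```python
import tensorflow as tf

# Decorating with tf.function lets TensorFlow trace the Python body into a graph.
@tf.function
def softmax_fn(x):
    return tf.nn.softmax(x)

# Tracing with a TensorSpec (rather than a concrete tensor) leaves the
# length dimension unknown, so the graph is not tied to one input size.
concrete = softmax_fn.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32))
graph = concrete.graph

# Each tf.Operation exposes its name, op type, and output tensors.
op_types = [op.type for op in graph.get_operations()]
for op in graph.get_operations():
    print(op.name, op.type, [t.shape for t in op.outputs])
```

You should see at least a `Placeholder` for the input and a `Softmax` op. The exact set of extra ops (for example, `Identity` on the output) can vary between TensorFlow versions.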
Note that the model’s forward pass needs to be a tf.function, and you’ll need to know at least the input shapes ahead of time. If the model is a Keras model, here’s a somewhat hacky way to do that:
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(1, activation='softmax', input_shape=(1,))
])
# model.predict_function is only populated after one call to predict
# (model.train_function works the same way, but is populated by fit)
model.predict(tf.constant([[1.0]]))  # shape (batch, 1) to match input_shape=(1,)
# predict_function takes an iterator over batches, so wrap the input in one
graph = model.predict_function.get_concrete_function(iter([tf.constant([[1.0]])])).graph
isinstance(graph, tf.Graph) # True
graph.get_operations()
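If you only need the graph of the model’s forward pass (not Keras’s whole predict loop), a possibly less hacky alternative is to wrap the model call in tf.function yourself and trace it with a `tf.TensorSpec`. This is a sketch assuming TF 2.x; the `None` batch dimension is a choice made here so the graph accepts any batch size, while the feature dimension (1) still has to be known ahead of time.

```python
import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation='softmax', input_shape=(1,))
])

# Trace the forward pass directly; training=False selects inference behavior.
forward = tf.function(lambda x: model(x, training=False))

# None leaves the batch size unspecified; the feature size must match the model.
concrete = forward.get_concrete_function(
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32))
graph = concrete.graph

op_types = {op.type for op in graph.get_operations()}
print(sorted(op_types))
```

Unlike the predict_function approach, this doesn’t require calling predict first, and the resulting graph contains only the model’s own computation (plus variable reads), without the iterator and data-handling ops.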
For more info, see the TensorFlow Core guide “Introduction to graphs and tf.function”.