Freezing with jit_compile=True

Hi all, I have been working on a model that makes liberal use of @tf.function(jit_compile=True). However, when I try to export the model as a frozen graph, the output tensors lose their static shapes (they come back as <unknown>). Removing jit_compile=True sidesteps the issue.

I have distilled this down into the following reproducer:

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

@tf.function(jit_compile=True)
def jit_compiled_fn(x, y):
    x_squared = tf.square(x)
    result = x_squared + y
    return x_squared, result

@tf.function
def not_jit_compiled_fn(x, y):
    return jit_compiled_fn(x, y)

concrete_fn = not_jit_compiled_fn.get_concrete_function(
    tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
    tf.TensorSpec(shape=[1, 2], dtype=tf.float32))

print("Before Freezing:")
print(concrete_fn.structured_outputs)

frozen_fn = convert_variables_to_constants_v2(concrete_fn)

print("After Freezing:")
# Print the name and static output shape of every Identity op in the frozen graph.
for op in frozen_fn.graph.get_operations():
    if op.type == 'Identity':
        print(f"{op.name}: {op.outputs[0].shape}")
```

Which prints the following:

```
Before Freezing:
(<tf.Tensor 'Identity:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 2) dtype=float32>)

After Freezing:
Identity: <unknown>
Identity_1: <unknown>
```

Removing the jit_compile kwarg, by contrast, prints:

```
Before Freezing:
(<tf.Tensor 'Identity:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 2) dtype=float32>)

After Freezing:
Func/PartitionedCall/input/_0: (1, 2)
PartitionedCall/Identity: (1, 2)
Func/PartitionedCall/output/_2: (1, 2)
Identity: (1, 2)
Func/PartitionedCall/input/_1: (1, 2)
PartitionedCall/Identity_1: (1, 2)
Func/PartitionedCall/output/_3: (1, 2)
Identity_1: (1, 2)
```

It is also worth noting that adding jit_compile=True to not_jit_compiled_fn also hides the output shapes. However, freezing the concrete function of jit_compiled_fn directly does not.
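To compare these configurations in a single run, here is a minimal sketch that rebuilds the two functions with configurable flags and freezes each variant the same way as the reproducer. make_fns and frozen_identity_shapes are hypothetical helpers introduced for illustration; passing None leaves jit_compile unset (the tf.function default). It still relies on the private convert_variables_to_constants_v2 import, whose location may change between releases, and I haven't checked the printed shapes across TF versions, so treat the output as something to compare rather than a known result.

```python
import tensorflow as tf
# Private module, as in the reproducer; its location may change between TF releases.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

SPEC = tf.TensorSpec(shape=[1, 2], dtype=tf.float32)


def make_fns(inner_jit, outer_jit):
    """Hypothetical helper: rebuild the reproducer's pair of functions.

    None leaves jit_compile unset (the tf.function default); True forces XLA.
    """

    @tf.function(jit_compile=inner_jit)
    def inner(x, y):
        x_squared = tf.square(x)
        return x_squared, x_squared + y

    @tf.function(jit_compile=outer_jit)
    def outer(x, y):
        return inner(x, y)

    return inner, outer


def frozen_identity_shapes(fn):
    """Hypothetical helper: freeze fn for SPEC inputs, map Identity op names to shapes."""
    concrete = fn.get_concrete_function(SPEC, SPEC)
    frozen = convert_variables_to_constants_v2(concrete)
    return {
        op.name: op.outputs[0].shape
        for op in frozen.graph.get_operations()
        if op.type == "Identity"
    }


if __name__ == "__main__":
    configs = {
        "neither": (None, None),
        "inner only": (True, None),
        "outer only": (None, True),
        "inner and outer": (True, True),
    }
    for label, (inner_jit, outer_jit) in configs.items():
        _, outer = make_fns(inner_jit, outer_jit)
        print(f"outer, jit on {label}:", frozen_identity_shapes(outer))

    # Freezing the jit-compiled inner function on its own.
    inner, _ = make_fns(True, None)
    print("inner alone, jit on:", frozen_identity_shapes(inner))
```

Rebuilding the functions inside make_fns keeps each configuration's traces separate, so one variant's cached concrete function can't leak into another.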

I'd like to sanity-check my approach before reporting this as a bug; it's far more likely that I'm missing some AutoGraph or XLA nuance.

EDIT: I had forgotten that since TF 2.0, frozen graphs are deprecated (or at least unsupported), so perhaps it's not surprising that XLA compilation of functions interferes with freezing, which is a holdover from TF1 sessions.
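For comparison, the TF2-native export path is a SavedModel, which stores the tf.function graphs in the model's function library rather than converting them into a single frozen GraphDef. Below is a minimal sketch, assuming the computation can be wrapped in a tf.Module; SquarePlus and export_and_reload are hypothetical names. Whether a SavedModel suits a particular deployment target (for example, one that only accepts GraphDefs) is a separate question.

```python
import tempfile

import numpy as np
import tensorflow as tf

SPEC = tf.TensorSpec(shape=[1, 2], dtype=tf.float32)


class SquarePlus(tf.Module):
    """Hypothetical module wrapping the reproducer's computation for export."""

    @tf.function(input_signature=[SPEC, SPEC], jit_compile=True)
    def compute(self, x, y):
        x_squared = tf.square(x)
        return x_squared, x_squared + y


def export_and_reload(export_dir):
    """Save the module as a SavedModel and load it back (no variable freezing)."""
    tf.saved_model.save(SquarePlus(), export_dir)
    return tf.saved_model.load(export_dir)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        loaded = export_and_reload(tmp)
        x = tf.constant([[1.0, 2.0]])
        y = tf.constant([[10.0, 20.0]])
        x_squared, result = loaded.compute(x, y)
        print(x_squared.numpy(), result.numpy())
```

The fixed input_signature means the exported function has a single concrete [1, 2] float32 signature, matching the TensorSpecs used in the reproducer.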
