Slow initialization of model with dynamic batch size in the C API

I’m using the TensorFlow C API to do inference in the Essentia library, and I’m facing very slow loading times with a new model that we want to support.
Since I found that the problem can be reproduced with TensorFlow in Python, I thought someone here could give me some feedback.

My model is an EfficientNet trained in PyTorch and converted to TensorFlow via the ONNX-TF tool.
I discovered that the model is very slow to load when the batch size is set as a dynamic dimension (which is convenient for adjusting the amount of parallelization to the available GPU memory).
To investigate this, I created two versions of the model, one with dynamic batch size and one with a fixed batch size of 1.
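
In case it is relevant, this is roughly what the conversion pipeline looks like. It is only a sketch: the stand-in module below replaces the real EfficientNet, the file and directory names are made up, and the input/output names just mirror the signature used later in this post. The key part is the dynamic_axes argument, which marks the batch dimension as dynamic; dropping it gives the fixed batch size of 1.

import torch
import torch.nn as nn
import onnx
from onnx_tf.backend import prepare

# Stand-in for the real EfficientNet; only the export mechanics matter here.
class TinyNet(nn.Module):
    def forward(self, x):
        return x.mean(dim=(1, 2))

model = TinyNet().eval()
dummy = torch.randn(1, 128, 96)

torch.onnx.export(
    model, dummy, "model.onnx",
    opset_version=11,
    input_names=["melspectrogram"],
    output_names=["output_0"],
    # Mark dimension 0 (batch) as dynamic; omit this argument to get the
    # fixed-batch-size version of the model.
    dynamic_axes={"melspectrogram": {0: "batch"}},
)

tf_rep = prepare(onnx.load("model.onnx"))
tf_rep.export_graph("model_dynamic_axis")  # writes a TensorFlow SavedModel directory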

When I test the models in Python with TF2.7, both are equally fast:

from argparse import ArgumentParser
from time import time

import tensorflow as tf
import numpy as np

def get_rnd_float32(low=-1.0, high=1.0, shape=None):
    output = np.random.uniform(low, high, shape)
    return output.astype(np.float32)


parser = ArgumentParser()
parser.add_argument("model_name")
args = parser.parse_args()
model_name = args.model_name

shape = [1, 128, 96]  # a single input patch; the first dimension is the batch axis
x = get_rnd_float32(shape=shape)

# Time how long it takes to load the SavedModel.
start = time()
tf_model = tf.saved_model.load(model_name)
print(f"{model_name} loading time: {time() - start:.1f}s")

# Time a single forward pass.
start = time()
tf_model_output = tf_model(melspectrogram=x)
print(f"{model_name} inference time: {time() - start:.1f}s")

with effnet_opset11_fixed_axis:

>>> effnet_opset11_fixed_axis loading time: 1.0s
>>> effnet_opset11_fixed_axis inference time: 0.5s

with effnet_opset11_dynamic_axis:

>>> effnet_opset11_dynamic_axis loading time: 1.1s
>>> effnet_opset11_dynamic_axis inference time: 0.5s
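
(As a sanity check, the dynamic axis can be confirmed from the loaded model’s serving signature; in the sketch below the batch dimension shows up as None for the dynamic-axis model and as 1 for the fixed-axis one.)

import tensorflow as tf

loaded = tf.saved_model.load("effnet_opset11_dynamic_axis")
# Prints the TensorSpecs of the serving signature, including the batch dimension.
print(loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].structured_input_signature)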

Interestingly, using the legacy version (compat.v1) I am able to reproduce the problem I find in the C API:

from argparse import ArgumentParser
from time import time

import tensorflow as tf
import numpy as np

tf.compat.v1.disable_eager_execution()

SHAPE = [1, 128, 96]

def get_rnd_float32(low=-1.0, high=1.0, shape=None):
    output = np.random.uniform(low, high, shape)
    return output.astype(np.float32)


parser = ArgumentParser()
parser.add_argument("model_name")
args = parser.parse_args()

model_name = args.model_name

data = get_rnd_float32(shape=SHAPE)

with tf.Graph().as_default():
    with tf.compat.v1.Session() as sess:
        # Time how long it takes to load the SavedModel into the session.
        start = time()
        meta_graph = tf.compat.v1.saved_model.load(sess, ["serve"], model_name)
        print(f"{model_name} meta_graph loading time: {time() - start:.1f}s")

        # Resolve the input and output tensor names from the serving signature.
        sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        input_name = sig_def.inputs['melspectrogram'].name
        output_name = sig_def.outputs['output_0'].name

        # Time a single forward pass.
        start = time()
        sess.run(output_name, feed_dict={input_name: data})
        print(f"{model_name} inference time: {time() - start:.1f}s")

with effnet_opset11_fixed_axis:

>>> effnet_opset11_fixed_axis meta_graph loading time: 0.7s
>>> effnet_opset11_fixed_axis inference time: 0.8s

with effnet_opset11_dynamic_axis:

>>> effnet_opset11_dynamic_axis meta_graph loading time: 310.1s
>>> effnet_opset11_dynamic_axis inference time: 0.9s

Another observation is that a ResNet50 converted through the same pipeline does not suffer from this slow-initialization issue, with or without the dynamic axis.
Thus, I suspect the issue is related to the DepthwiseConv2D layers (the main difference between the two models) combined with a dynamic dimension.
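
To check that suspicion at the graph level, one option is to count the op types that actually end up in each SavedModel’s serving graph. This is only a sketch: the directory names are the ones used above, and nested function bodies (e.g. control-flow functions) are not expanded.

from collections import Counter

import tensorflow as tf

for name in ["effnet_opset11_fixed_axis", "effnet_opset11_dynamic_axis"]:
    loaded = tf.saved_model.load(name)
    # GraphDef of the serving ConcreteFunction (outer graph only).
    graph_def = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
    counts = Counter(node.op for node in graph_def.node)
    print(name, counts.most_common(10))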

I would really appreciate any comment that helps me understand what is going on, either at the C API level or in the Python code above.

Thanks,
Pablo.

I’ve run the profiler on the effnet_opset11_dynamic_axis model both in the default mode and in V1 (compat.v1) mode and got very different results:

TensorFlow 2.7 V1 (compat.v1)

Type                    #Occurrences   Total self-time (us)   Total self-time on Host (%)
SplitV                  227,222        1,259,431              35%
Sub                     227,222        987,115                27.5%
Pack                    240,590        755,560                21%
Range                   227,222        559,939                15.6%
Conv2D                  65             15,354                 0.4%
Transpose               153            7,452                  0.2%
DepthwiseConv2dNative   16             2,903                  0.1%
Pad                     17             2,535                  0.1%
Sigmoid                 66             1,725                  0%
Mul                     65             988                    0%
Mean                    17             250                    0%
MatMul                  1              239                    0%
StridedSlice            3              58                     0%
AddV2                   9              41                     0%
Prod                    2              35                     0%
Placeholder             1              12                     0%

TensorFlow 2.7 default behavior

Type                    #Occurrences   Total self-time (us)   Total self-time on Host (%)
Transpose               170            37,155                 59.6%
Conv2D                  64             14,918                 23.9%
DepthwiseConv2dNative   16             2,909                  4.7%
Pad                     17             2,482                  4%
Sigmoid                 33             1,387                  2.2%
AddV2                   58             1,219                  2%
Mul                     47             786                    1.3%
SplitV                  51             280                    0.4%
Sub                     51             218                    0.3%
Mean                    17             218                    0.3%
MatMul                  1              206                    0.3%
Pack                    56             179                    0.3%
Range                   51             120                    0.2%
StridedSlice            3              50                     0.1%
Prod                    2              32                     0.1%
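
For anyone who wants to reproduce these numbers: a comparable trace can be collected programmatically with the TF2 profiler and inspected in TensorBoard’s Profile tab. This is only a sketch, with a placeholder workload and an arbitrary log directory, meant to wrap the loading/inference code from the scripts above.

import tensorflow as tf

tf.profiler.experimental.start("profile_logdir")  # arbitrary log directory
# ... run the model loading / inference to be profiled here ...
_ = tf.matmul(tf.random.normal([64, 64]), tf.random.normal([64, 64]))  # placeholder workload
tf.profiler.experimental.stop()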

Apart from the excessive number of SplitV, Sub, Pack, and Range ops in V1 mode, I’ve noticed that the counts of convolutions, transposes, etc. don’t match either, so it seems like the graph is being optimized differently.

I think that, by default, the C API loads the model in the same way as TensorFlow 2.7 in V1 mode. Is there a way to make the C API treat the model the way TensorFlow 2.7 does by default?
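
One knob that might help narrow this down on the compat.v1 side is the Grappler configuration of the Session. The sketch below creates the Session with the meta-optimizer disabled, to check whether the ~300 s is spent in a rewrite pass; this is just an experiment I am considering, and, if I’m not mistaken, the same ConfigProto could be passed to the C API in serialized form via TF_SetConfig.

import tensorflow as tf

# Disable Grappler's meta-optimizer for this session; the rest of the
# compat.v1 script above stays the same.
config = tf.compat.v1.ConfigProto()
config.graph_options.rewrite_options.disable_meta_optimizer = True

with tf.Graph().as_default():
    with tf.compat.v1.Session(config=config) as sess:
        pass  # tf.compat.v1.saved_model.load(sess, ["serve"], model_name) as above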