TFLite model interpreter error: RuntimeError: tensorflow/lite/kernels/reshape.cc:92 num_input_elements != num_output_elements (1 != 16000)Node number 8 (RESHAPE) failed to prepare. Failed to apply the default TensorFlow Lite delegate indexed at 0

Hi all,

I have a TensorFlow Lite model converted from a .h5 Keras model. The purpose of this model is to perform bandwidth extension on an input audio file (an 8 kHz file is extended to 16 kHz by the model).
The Keras model works fine, but the tflite version raises this error at the line interpreter.allocate_tensors():

"RuntimeError: tensorflow/lite/kernels/reshape.cc:92 num_input_elements != num_output_elements (1 != 16000)Node number 8 (RESHAPE) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0."

NOTE: The .h5 model itself was created from a trained PyTorch model (GitHub - zeroone-universe/RealTimeBWE: Unofficial Pytorch Lightning Implementation of "Real-time Speech Frequency Bandwidth Extension"), and the conversion was carried out using this library: GitHub - AlexanderLutsenko/nobuco: Pytorch to Keras/Tensorflow conversion made intuitive

The input and output shapes reported by the tflite model are valid and as expected.
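Checking the shapes does not require allocate_tensors(): reading the metadata alone works even though allocation fails. A minimal sketch of that check (the shapes in the comments are what I expect from the model description):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="C:/Downloads/rtmlbwe_test2.tflite")
# get_input_details()/get_output_details() only read model metadata,
# so they succeed even though allocate_tensors() raises
print(interpreter.get_input_details()[0]['shape'])   # e.g. [1 8000 1]
print(interpreter.get_output_details()[0]['shape'])  # e.g. [1 16000 1]

This is the conversion script: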

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import nobuco
from nobuco import ChannelOrder, ChannelOrderingStrategy, TraceLevel
from nobuco.layers.weight import WeightLayer

import torch
from pytorch_lightning import Trainer
import tensorflow as tf
from tensorflow.lite.python.lite import TFLiteConverter
import keras

from train import RTBWETrain
import yaml
config = yaml.load(open("C:/Downloads/config.yaml", 'r'), Loader=yaml.FullLoader)

trainer = Trainer()
model = RTBWETrain(config=config)
model.trainer = trainer
state_dict = torch.load("C:/Downloads/pytorch_model.pt")
model.generator.load_state_dict(state_dict)
model.generator.eval()

input_shape = (1, 1, 8000)
dummy_input = torch.randn(input_shape)
traced_model = torch.jit.trace(model, dummy_input)
# Sanity check: run the traced model on the dummy input
output = traced_model(dummy_input)

# Print the size of the output
print(output.shape)

# Conversion to TF

keras_model = nobuco.pytorch_to_keras(
    traced_model,
    args=[dummy_input],
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW,
)

model_path = 'C:/Downloads/rtmlbwe_test2'

keras_model.save(model_path + '.h5')
print('Model saved')

custom_objects = {'WeightLayer': WeightLayer}

keras_model_restored = keras.models.load_model(model_path + '.h5', custom_objects=custom_objects)
print('Model loaded')

converter = TFLiteConverter.from_keras_model_file(model_path + '.h5', custom_objects=custom_objects)
# target_ops is deprecated; target_spec.supported_ops is the current attribute
converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
tflite_model = converter.convert()
with open(model_path + '.tflite', 'wb') as f:
    f.write(tflite_model)
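
I am aware the TF2 converter API can also be used directly on the in-memory Keras model, avoiding the legacy from_keras_model_file path. A minimal sketch of that variant, reusing the objects defined above (untested on my side):

converter2 = tf.lite.TFLiteConverter.from_keras_model(keras_model_restored)
converter2.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model2 = converter2.convert()
with open(model_path + '_v2.tflite', 'wb') as f:
    f.write(tflite_model2)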

I also tried converting using a concrete function:

# pytorch to keras conversion
keras_model = nobuco.pytorch_to_keras(
    traced_model,
    args=[dummy_input], kwargs=None,
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW,
)
saved_model_path = model_path + '_saved_model'  # directory for the SavedModel export
tf.saved_model.save(keras_model, saved_model_path)

# Load the saved model (recommended format by tf docs)
loaded_model = tf.saved_model.load(saved_model_path)

# Create concrete function with defined input shape
tf_input_shape = (1, 8000, 1)
concrete_func = loaded_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape(tf_input_shape)

# Convert to tflite
# Passing loaded_model as the trackable object keeps its variables alive during conversion
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func], loaded_model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
with open(model_path + '.tflite', 'wb') as f:
    f.write(tflite_model)
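
Either way, the failure is reproducible without any audio by feeding random input to the freshly converted model. A minimal sketch, assuming the shapes above:

import numpy as np

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()  # this is where the RESHAPE error is raised
dummy = np.random.randn(1, 8000, 1).astype(np.float32)
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], dummy)
interpreter.invoke()
print(interpreter.get_tensor(interpreter.get_output_details()[0]['index']).shape)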

This is the inference script, which generates the error at the line interpreter.allocate_tensors():

import tensorflow as tf
import numpy as np
import librosa
import soundfile
import os

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="C:/Downloads/rtmlbwe_test2.tflite")


# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print(input_details)
print(output_details)

interpreter.allocate_tensors()
# Define the path to your audio file as a variable
audio_file_path = "C:/Downloads/audio.wav"

# Load the audio at the model's expected input rate; librosa resamples to
# 22050 Hz by default, so sr=8000 must be passed explicitly
audio, sr = librosa.load(audio_file_path, sr=8000)

# Split the audio into chunks of roughly 8000 samples each
chunks = np.array_split(audio, len(audio) // 8000)

# Initialize an empty list to store the processed outputs
processed_outputs = []

# Iterate over the chunks
for chunk in chunks:
    # Reshape the chunk to the model's input shape; np.resize repeats
    # samples to pad any chunk shorter than 8000
    chunk = np.resize(chunk, (1, 8000, 1)).astype(np.float32)

    # Run the inference
    interpreter.set_tensor(input_details[0]['index'], chunk)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Append the output to the list of processed outputs
    processed_outputs.append(output_data[0])

# Concatenate all the processed outputs to create the final full audio file
final_output = np.concatenate(processed_outputs)


# Define the path for the processed output file
base_name = os.path.basename(audio_file_path)  # Get the original file name
name_without_ext = os.path.splitext(base_name)[0]  # Remove the extension
new_name = name_without_ext + '_BWE.wav'  # Append '_BWE' to the original file name

# Save the processed output
output_file_path = os.path.join('C:/Downloads', new_name)
soundfile.write(output_file_path, final_output, 16000)

Here are the first 16 operations of the tflite model printed out. Note that Op#7 STRIDED_SLICE carries a huge negative value in its begin/end/strides parameters:

  Op#0 RESHAPE(T#0, T#221[-1, 8000]) -> [T#235]
  Op#1 PAD(T#235, T#228[0, 0, 7, 8]) -> [T#236]
  Op#2 RESHAPE(T#236, T#129[1, 1, 8015]) -> [T#237]
  Op#3 STRIDED_SLICE(T#237, T#130[0, 0, 0], T#131[0, 0, 8015], T#132[1, 1, 1]) -> [T#238]
  Op#4 RESHAPE(T#238, T#85[1, 1, 8015, 1]) -> [T#239]
  Op#5 DEPTHWISE_CONV_2D(T#239, T#134, T#133) -> [T#240]
  Op#6 RESHAPE(T#240, T#220[1, -1]) -> [T#241]
  Op#7 STRIDED_SLICE(T#241, T#135[0, -1922558064], T#135[0, -1922558064], T#136[1, -1922558064]) -> [T#242]
  Op#8 RESHAPE(T#242, T#219[1, 1, 16000]) -> [T#243]
  Op#9 PAD(T#243, T#227[0, 0, 0, 0, 6, ...]) -> [T#244]
  Op#10 RESHAPE(T#244, T#86[1, 1, 16006, 1]) -> [T#245]
  Op#11 DEPTHWISE_CONV_2D(T#245, T#137, T#87) -> [T#246]
  Op#12 RESHAPE(T#246, T#88[1, 16000, 8]) -> [T#247]
  Op#13 ELU(T#247) -> [T#248]
  Op#14 PAD(T#248, T#225[0, 0, 2, 0, 0, ...]) -> [T#249]
  Op#15 RESHAPE(T#249, T#89[1, 1, 16002, 8]) -> [T#250]
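
(For anyone who wants to inspect the full graph, a listing in this format can be produced with the TFLite analyzer API:)

import tensorflow as tf

# Prints the op-by-op structure of the converted model
tf.lite.experimental.Analyzer.analyze(model_path="C:/Downloads/rtmlbwe_test2.tflite")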

Reading the dump, the begin, end and strides inputs of Op#7 (T#135, T#136) all contain the same large negative value, which looks like an uninitialized or corrupted constant. A slice with garbage indices would collapse T#242 to a single element, so Op#8's RESHAPE to [1, 1, 16000] fails with exactly the reported mismatch (1 != 16000). It seems the slice parameters get corrupted somewhere along the PyTorch -> Keras -> TFLite chain. I appreciate any assistance to help resolve this issue. Thank you for your time!
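
One datapoint that might help narrow this down: running the restored .h5 model directly shows whether the graph is already broken before the TFLite step. A minimal sketch, reusing the paths from the conversion script:

import numpy as np
import keras
from nobuco.layers.weight import WeightLayer

# If this prints (1, 16000, 1), the corruption is introduced by the
# Keras -> TFLite conversion rather than by the nobuco step
model = keras.models.load_model('C:/Downloads/rtmlbwe_test2.h5',
                                custom_objects={'WeightLayer': WeightLayer})
out = model.predict(np.random.randn(1, 8000, 1).astype(np.float32))
print(out.shape)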

My TensorFlow version is 2.14.0.