TFLite does not support LSTM?

1. When converting the LSTM model to TFLite, the following WARNING message appears:

WARNING:absl:Found untraced functions such as lstm_cell_16_layer_call_and_return_conditional_losses, lstm_cell_16_layer_call_fn, lstm_cell_17_layer_call_and_return_conditional_losses, lstm_cell_17_layer_call_fn, lstm_cell_18_layer_call_and_return_conditional_losses while saving (showing 5 of 40). These functions will not be directly callable after loading.

2. First, I ignored the message and measured the test accuracy of the TFLite model on my computer.

It shows only 14% accuracy! Before the conversion, it was 94%.

3. Here’s the code I used to convert the model:

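```python
import tensorflow as tf
from tensorflow.keras.models import load_model

# Load the trained multi-input LSTM model saved in HDF5 format
h5_model = load_model('/content/multi_lstm.h5')

converter_h5 = tf.lite.TFLiteConverter.from_keras_model(h5_model)
converter_h5.optimizations = [tf.lite.Optimize.DEFAULT]  # dynamic-range quantization
converter_h5.experimental_new_converter = True
# Use TFLite builtin ops where possible, fall back to Select TF ops otherwise
converter_h5.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                          tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter_h5.convert()
```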

I don’t know why this is happening. Any help would be greatly appreciated.

+ My model is a multi-LSTM model with multiple inputs.
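One way to narrow down an accuracy drop like this one is to feed the same samples to the Keras model and to the TFLite interpreter and compare the raw outputs. The helper below is a minimal sketch, assuming both sets of outputs have already been collected as NumPy arrays; the name `compare_outputs` and the example numbers are hypothetical. The TFLite outputs are reshaped to match the Keras ones, since collecting per-sample `(1, 7)` outputs in a list gives an array of shape `(N, 1, 7)`.

```python
import numpy as np

def compare_outputs(keras_out, tflite_out):
    """Compare Keras and TFLite outputs for the same samples.

    keras_out: (num_samples, num_classes) array
    tflite_out: same number of elements, e.g. (num_samples, 1, num_classes)
    Returns (largest absolute score difference, fraction of samples whose
    predicted class, i.e. argmax, agrees).
    """
    keras_out = np.asarray(keras_out, dtype=np.float32)
    tflite_out = np.asarray(tflite_out, dtype=np.float32).reshape(keras_out.shape)
    max_abs_diff = float(np.max(np.abs(keras_out - tflite_out)))
    agreement = float(np.mean(keras_out.argmax(axis=1) == tflite_out.argmax(axis=1)))
    return max_abs_diff, agreement

# Hypothetical example: same predicted classes, small numeric drift
keras_out = np.array([[0.9, 0.1], [0.2, 0.8]])
tflite_out = np.array([[0.85, 0.15], [0.3, 0.7]])
print(compare_outputs(keras_out, tflite_out))
```

A large disagreement on identical samples points at the conversion or at how the inputs are fed to the interpreter, rather than at the test data itself.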

Hi @philip, if possible, could you please provide standalone code to reproduce the issue? Thank you.

@Kiran_Sai_Ramineni Thanks for commenting!

The data has 12 features and the output has seven classes.
The model is a multi-input LSTM with four inputs, each of shape (batch_size, 300, 3).
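To make the input layout concrete, here is a minimal NumPy sketch of splitting a `(num_samples, 300, 12)` array into the four `(num_samples, 300, 3)` inputs. It assumes the 12 features are grouped in consecutive blocks of three, which is the same slicing the test code in this thread uses; `split_inputs` is a hypothetical helper name.

```python
import numpy as np

def split_inputs(x, num_inputs=4, features_per_input=3):
    """Split an (N, T, num_inputs * features_per_input) array into a list of
    num_inputs arrays, each of shape (N, T, features_per_input)."""
    if x.shape[-1] != num_inputs * features_per_input:
        raise ValueError(
            f"expected {num_inputs * features_per_input} features, got {x.shape[-1]}")
    # Consecutive blocks of features: 0:3, 3:6, 6:9, 9:12
    return [x[:, :, k * features_per_input:(k + 1) * features_per_input]
            for k in range(num_inputs)]

# Hypothetical batch: 5 samples, 300 time steps, 12 features
x = np.random.rand(5, 300, 12).astype(np.float32)
x1, x2, x3, x4 = split_inputs(x)
print(x1.shape)  # (5, 300, 3)
```

For the Keras model the four arrays would be passed as a list in the order of `model.inputs`; for the TFLite interpreter each one goes into its own input tensor.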

```python
import numpy as np
import pandas as pd
import tensorflow as tf

x_test = pd.read_csv('.\\data\\x_test.csv')
y_test = pd.read_csv('.\\data\\y_test.csv')
#%%
def tflite_test(model_path, x_test, y_test):
    # Load the converted model and allocate its input/output buffers
    interpreter_multi = tf.lite.Interpreter(model_path=model_path)
    interpreter_multi.allocate_tensors()

    input_details = interpreter_multi.get_input_details()
    output_details = interpreter_multi.get_output_details()

    # Each sample is 300 time steps x 12 features
    x_test = x_test.to_numpy(dtype=np.float32).reshape(-1, 300, 12)
    labels = y_test.to_numpy().ravel()  # one integer class label per sample

    output = []
    for input_data in x_test:
        input_data = input_data.reshape(1, 300, 12)

        # Split the 12 features into four inputs of shape (1, 300, 3)
        x1 = input_data[:, :, 0:3]
        x2 = input_data[:, :, 3:6]
        x3 = input_data[:, :, 6:9]
        x4 = input_data[:, :, 9:12]

        # NOTE: this assumes input_details[i] is the i-th Keras input.
        # The converted model's input order may differ, so check
        # input_details[i]['name'] before relying on positions.
        interpreter_multi.set_tensor(input_details[0]['index'], x1)
        interpreter_multi.set_tensor(input_details[1]['index'], x2)
        interpreter_multi.set_tensor(input_details[2]['index'], x3)
        interpreter_multi.set_tensor(input_details[3]['index'], x4)

        interpreter_multi.invoke()

        output_data = interpreter_multi.get_tensor(output_details[0]['index'])
        output.append(output_data)

    # Predicted class = index of the largest score in each (1, 7) output
    predict_label = np.array([out.argmax() for out in output])
    right_answer = int(np.sum(predict_label == labels))

    accuracy = (right_answer / len(predict_label)) * 100
    print(f'accuracy : {accuracy:.3f}%')
```

Hi @philip, if possible, could you please provide the converted TFLite model and some sample data from your dataset, so that I can reproduce and understand the issue better? Thank you.

Hi @Kiran_Sai_Ramineni, sorry for the late reply.

If possible, send me your email address and I’ll share it with you on Google Drive.

Thank you.

Hi @philip, you can make it public and share the link here so that I can access it. Thank you.

Hi @Kiran_Sai_Ramineni, I have some good news!
I visualized the structure of the converted model and found the cause of this problem.
The order of the inputs was reversed during the conversion, and once I fixed that, it worked fine.

However, compared to other models (multi CNN, etc.), the LSTM is exceptionally slow. Do you know why?

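Hi @philip, LSTM networks tend to be computationally more expensive than CNNs. Each time step depends on the hidden and cell state from the previous step, so the 300 steps of each input have to be processed one after another instead of in parallel, and every step computes four gates (input, forget, cell, output), each with its own input and recurrent weight matrices. This adds up to more sequential computation during both training and inference. Thank you.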
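The sequential dependence is easiest to see in a plain NumPy forward pass of a single LSTM layer. This is a simplified illustration using the standard LSTM equations with Keras' gate order (i, f, c, o) and random weights; it is not how TensorFlow Lite implements its LSTM kernels. The layer size of 64 units is a hypothetical choice.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, W, U, b):
    """x: (T, input_dim); W: (input_dim, 4*units); U: (units, 4*units); b: (4*units,).
    Returns the final hidden state h_T of shape (units,)."""
    units = U.shape[0]
    h = np.zeros(units)
    c = np.zeros(units)
    for t in range(x.shape[0]):          # strictly sequential: step t needs h, c from t-1
        z = x[t] @ W + h @ U + b         # input and recurrent matrix products for all 4 gates
        i = sigmoid(z[:units])           # input gate
        f = sigmoid(z[units:2 * units])  # forget gate
        g = np.tanh(z[2 * units:3 * units])  # candidate cell state
        o = sigmoid(z[3 * units:])       # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
    return h

def lstm_param_count(input_dim, units):
    # Four gates, each with input weights, recurrent weights and a bias
    return 4 * (input_dim * units + units * units + units)

rng = np.random.default_rng(0)
input_dim, units, T = 3, 64, 300
W = rng.normal(scale=0.1, size=(input_dim, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)
h = lstm_forward(rng.normal(size=(T, input_dim)), W, U, b)
print(h.shape)                              # (64,)
print(lstm_param_count(input_dim, units))   # 17408
```

If each of the four inputs feeds its own LSTM layer, this loop runs 4 × 300 = 1,200 dependent steps per sample per layer, whereas a convolution can process all positions of the sequence in parallel.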