Custom convolutional op in TF Lite

Hi, I need to implement a custom convolution layer that is not supported by TensorFlow or TF Lite. I defined it following the TensorFlow tutorial on creating a custom op and the TF Lite guide on supporting custom operators. However, when I try to convert the op with the TF Lite converter, I get this error:

 Traceback (most recent call last):
 File "es.py", line 39, in <module>
 converter =  
 tf.lite.TFLiteConverter.from_concrete_functions([tf.function(convol1d).get_concrete_function(inp)])
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py",    
 line 1299, in get_concrete_function
 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py",    
 line 1205, in _get_concrete_function_garbage_collected
 self._initialize(args, kwargs, add_initializers_to=initializers)
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py",     
 line 725, in _initialize 
 self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-   
 access
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line  
 2969, in _get_concrete_function_internal_garbage_collected
 graph_function, _ = self._maybe_define_function(args, kwargs)
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line    
 3314, in _maybe_define_function
 self._function_spec.canonicalize_function_inputs(*args, **kwargs)
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line     
 2697, in canonicalize_function_inputs
 inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line  
 2753, in _convert_numpy_inputs
 a = _as_ndarray(value)
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line    
 2711, in _as_ndarray
 return value.__array__()
 File "/home/em/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine   
 /keras_tensor.py", line 273, in __array__
 raise TypeError(
 TypeError: Cannot convert a symbolic Keras input/output to a numpy array. This error may    
 indicate that you're trying to pass a symbolic value to a NumPy call, which is not supported. Or,   
 you may be trying to pass Keras symbolic inputs/outputs to a TF API that does not register 
 dispatching, preventing Keras from automatically converting the API call to a lambda layer in the 
 Functional Model.

Here is the code:

import tensorflow as tf
tf.config.run_functions_eagerly(True)
from keras.datasets import mnist
from keras.models import Model
from keras.layers import add,Input,Activation,Flatten,Dense

def convol(inp):
    """Run the custom ``Conv`` op compiled into ``./conv.so`` on *inp*.

    The shared library is loaded on every call and the op is invoked with
    the graph name ``"Conv"``; its result is returned unchanged.
    """
    return tf.load_op_library('./conv.so').conv(inp, name="Conv")

def read_mnist(path):
    """Return MNIST as ``train_x, train_y, test_x, test_y``.

    *path* is currently ignored: ``mnist.load_data()`` is called with no
    arguments.
    """
    train, test = mnist.load_data()
    return train[0], train[1], test[0], test[1]

def tcn(train_x, train_y, test_x, test_y):
    """Build, train and evaluate a ``convol`` -> Flatten -> Dense(10) model.

    Args:
        train_x, train_y: training images and integer labels.
        test_x, test_y: validation/test images and integer labels.

    Returns:
        The trained ``Model``, so the caller can reuse or export it.
        Callers that ignore the return value are unaffected.
    """
    inp = Input(shape=(28, 28))
    x = convol(inp)
    x = Flatten()(x)
    x = Dense(10, activation='softmax')(x)
    model = Model(inputs=inp, outputs=x)
    model.summary()
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    # These three lines were indented one column deeper than the rest of
    # the body, which raised IndentationError.
    model.fit(train_x, train_y, batch_size=100, epochs=10,
              validation_data=(test_x, test_y))
    pred = model.evaluate(test_x, test_y, batch_size=100)
    print('test_loss:', pred[0], '- test_acc:', pred[1])
    return model

 train_x,train_y,test_x,test_y=read_mnist('MNIST_data')
 tcn(train_x,train_y,test_x,test_y)
 tflite_model_name = 'net'
 inp=Input(shape=(28,28))
 converter =   
 tf.lite.TFLiteConverter.from_concrete_functions([tf.function(convol).get_concrete_function(inp)])
 converter.allow_custom_ops = True
 tflite_model = converter.convert()
 open(tflite_model_name + '.tflite', 'wb').write(tflite_model)

Hi @Gianni_Rossi

So the model trains properly but fails to convert? Is there a Colab notebook that reproduces the problem?

Hi, thanks for your reply. I can train the model, but I can’t convert it. Unfortunately, I have problems loading the .so file built from my custom TF conv op on Google Colab, so I can't provide a Colab notebook.

@George_Soloupis

I need a dilated, causal 1D convolution op that TF Lite can support, and TF already provides a Conv1D layer. So I think I can drop the custom TF conv op above and use the Conv1D class instead. Following the TensorFlow Lite tutorial "Custom operators", I tried to implement this very simple Conv1D kernel:

#include "tensorflow/lite/kernels/conv_1d.h"

#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <vector>
#include "tensorflow/lite/c/common.h"

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace conv_1d {

// Fixed (non-trainable) kernel hyper-parameters of this custom op.
constexpr int dim_k = 3;     // number of kernel taps
constexpr int dilation = 2;  // distance between consecutive taps
constexpr float kernel[dim_k] = {1.2f, 2.0f, 4.2f};

// Receptive field of one output sample, derived from the taps and dilation
// so it cannot drift out of sync with them (previously hard-coded as 5).
constexpr int dim = (dim_k - 1) * dilation + 1;

// NOTE(review): mutable file-scope state is shared by every node that uses
// this op, which makes the kernel non-reentrant. Kept only so that code still
// reading these names continues to compile.
int dim_in;
int dim_out;
float copy[dim];

// Prepare step: checks that the node has exactly one input and one output,
// both float32, and resizes the output to the input's shape (one output
// sample per input sample).
TfLiteStatus Conv1dPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  // Defensive: never dereference a missing tensor.
  TF_LITE_ENSURE(context, input != nullptr);
  TF_LITE_ENSURE(context, output != nullptr);

  // Eval reads/writes data.f directly, so reject any non-float tensor here
  // rather than reinterpreting its bytes at run time.
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);

  // ResizeTensor takes ownership of the copied dims array.
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

// Eval step: causal, dilated 1-D convolution with the fixed `kernel`.
//
//   out[i] = sum_{j < dim_k} kernel[j] * in[i + j*dilation - pad],
//   pad    = (dim_k - 1) * dilation,
//
// where samples before the start of the signal count as zero. The tensor
// buffer is treated as one flat 1-D signal.
TfLiteStatus Conv1dEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, input != nullptr);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);

  const float* input_data = GetTensorData<float>(input);
  float* output_data = GetTensorData<float>(output);

  // Whole-buffer lengths. This replaces the dims[0]/dims[1] guess, which read
  // past the dims array for rank-1 tensors.
  const int64_t len_in = NumElements(input);
  const int64_t len_out = NumElements(output);
  // Left zero-padding that makes the convolution causal (was hard-coded 4).
  const int64_t pad = static_cast<int64_t>(dim_k - 1) * dilation;

  for (int64_t i = 0; i < len_out; ++i) {
    // Accumulate into a local: the output buffer is not zero-initialised, so
    // adding onto output_data[i] (as before) summed onto garbage.
    float acc = 0.0f;
    for (int j = 0; j < dim_k; ++j) {
      const int64_t src = i + static_cast<int64_t>(j) * dilation - pad;
      // Padding samples (src < 0) contribute zero; also guard the tail.
      if (src >= 0 && src < len_in) {
        acc += input_data[src] * kernel[j];
      }
    }
    output_data[i] = acc;
  }
  return kTfLiteOk;
}


}  // namespace conv_1d

// Returns the registration for the CONV_1D custom op. The struct is a
// function-local static, so every call returns the same pointer. No
// init/free hooks are used; only prepare and invoke are set.
TfLiteRegistration* Register_CONV_1D() {
  static TfLiteRegistration registration = {
      /*init=*/nullptr,
      /*free=*/nullptr,
      /*prepare=*/conv_1d::Conv1dPrepare,
      /*invoke=*/conv_1d::Conv1dEval,
  };
  return &registration;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite

However, this kernel always uses fixed weights that are not trainable. How can I access the weights created by the Conv1D class, as well as its other inputs, from the op? I’d then like to use the following Python code to register the op:

import tensorflow as tf
tf.config.run_functions_eagerly(True)
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Conv1D,Activation,Input,Flatten,Dense

num_filters = 32
kern_size = 3

@tf.function
def convol(inp, num_filters, kern_size):
    """Apply a newly constructed ``Conv1D`` layer named ``"cd1"`` to *inp*.

    Args:
        inp: input tensor to convolve.
        num_filters: number of output filters (``Conv1D(filters=...)``).
        kern_size: kernel length (``Conv1D(kernel_size=...)``).

    Returns:
        The Conv1D layer's output tensor.
    """
    # NOTE(review): a new Conv1D (with new weights) is created on every call
    # or trace, so weights trained through one call are not shared with
    # another. Keras Conv1D also accepts padding="causal" and dilation_rate,
    # which may cover the dilated causal convolution without a custom op;
    # confirm.
    x = Conv1D(filters=num_filters, kernel_size=kern_size, name="cd1")(inp)
    # Previously the result was bound to `r` but `x` was returned -> NameError.
    return x

def read_mnist(path):
    """Return MNIST as ``train_x, train_y, test_x, test_y``.

    *path* is currently ignored: ``mnist.load_data()`` is called with no
    arguments.
    """
    train_split, test_split = mnist.load_data()
    train_x, train_y = train_split
    test_x, test_y = test_split
    return train_x, train_y, test_x, test_y

def tcn(train_x, train_y, test_x, test_y):
    """Build, train and evaluate a Conv1D -> Flatten -> Dense(10) classifier.

    Uses the module-level ``num_filters`` and ``kern_size`` for ``convol``.

    Args:
        train_x, train_y: training images and integer labels.
        test_x, test_y: validation/test images and integer labels.

    Returns:
        The trained ``Model`` (previously discarded), e.g. for export with
        ``tf.lite.TFLiteConverter.from_keras_model``. Callers that ignore
        the return value are unaffected.
    """
    inp = Input(shape=(28, 28))
    x = convol(inp, num_filters, kern_size)
    x = Flatten()(x)
    x = Dense(10, activation='softmax')(x)
    model = Model(inputs=inp, outputs=x)
    model.summary()
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    model.fit(train_x, train_y, batch_size=100, epochs=10,
              validation_data=(test_x, test_y))
    pred = model.evaluate(test_x, test_y, batch_size=100)
    print('test_loss:', pred[0], '- test_acc:', pred[1])
    return model

train_x, train_y, test_x, test_y = read_mnist('MNIST_data')
tcn(train_x, train_y, test_x, test_y)
tflite_model_name = 'net'
# Trace with a TensorSpec, not a symbolic keras Input: passing Input(...) to
# get_concrete_function raises "TypeError: Cannot convert a symbolic Keras
# input/output to a numpy array". The shape mirrors Input(shape=(28, 28))
# with an unknown batch dimension.
input_spec = tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32)
# NOTE(review): tracing convol builds a fresh Conv1D, so the exported graph
# does not carry the weights trained inside tcn(); confirm this is intended.
concrete_fn = convol.get_concrete_function(input_spec, num_filters, kern_size)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
converter.allow_custom_ops = True
tflite_model = converter.convert()
# Use a context manager so the file handle is closed deterministically.
with open(tflite_model_name + '.tflite', 'wb') as f:
    f.write(tflite_model)