Custom OP / C++: How to write multithreaded CPU kernel?

Hello,
I’m following the documentation as well as the GitHub repository to create my custom operation.

I noticed that the OP is ~2x slower on CPU when executed from C++, compared to the native Python implementation (without tf.function).

My guess is that my C++ CPU implementation runs on a single thread, while Python + TensorFlow probably does some optimization under the hood to make it more efficient.

There is a section about multi-threaded CPU kernels, but upon inspecting work_sharder.h it says:

// DEPRECATED: Prefer threadpool->ParallelFor with SchedulingStrategy (...)

Can anyone provide a simple example of the recommended way to shard time_two, or a similar operation, so that it runs on multiple CPU threads?


Hi.
How about checking this code?

@jeongukjae Thank you for the example.
Can you give some context for the begin and end values? I assume they are per-thread values. For example, given a tensor with 1000 elements in total, do I have to set these values myself?


OK, I think I eventually figured it out. Leaving what I found here in case anyone else stumbles upon this.

Step 1. Provide OpKernelContext to functor

In order to use ParallelFor we must provide OpKernelContext to the functor.
Consider for example time_two.h.

// time_two.h
#ifndef KERNEL_TIME_TWO_H_
#define KERNEL_TIME_TWO_H_

// op_kernel.h declares OpKernelContext
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace functor {

template <typename Device, typename T>
struct TimeTwoFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

template <typename Device, typename T>
struct TimeTwoParallelFunctor {
  void operator()(OpKernelContext* ctx, int size, const T* in, T* out);
};

}  // namespace functor
}  // namespace tensorflow
#endif //KERNEL_TIME_TWO_H_

Step 2. Provide implementation in your kernel.cc

The implementation of TimeTwoParallelFunctor would look as follows:

#include "tensorflow/core/util/work_sharder.h"  // needed for thread_pool->ParallelFor

template <typename T>
struct TimeTwoParallelFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, int size, const T* in, T* out) {
    auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;

    // ParallelFor(total, cost_per_unit, fn): the second argument is the
    // estimated cost of one element. size*1000 is a guess; I'm not quite sure
    // how to set it properly.
    thread_pool->ParallelFor(
      size, size*1000,
      [in, out](int64 start_index, int64 end_index) {
        // Each call handles the half-open range [start_index, end_index).
        for (int64 i = start_index; i < end_index; i++) {
          out[i] = 2 * in[i];
        }
    });
  }
};

Make sure to include tensorflow/core/util/work_sharder.h, or you will get an error when attempting to use thread_pool->ParallelFor.
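
To make the start_index and end_index arguments concrete: you do not set them yourself. ParallelFor splits the total into ranges and calls the lambda once per range. Below is a minimal standalone sketch of that idea using std::thread. ShardedFor and TimeTwoSharded are hypothetical names of mine, not TensorFlow APIs, and the fixed shard count is a simplification: as far as I understand, the real ParallelFor derives block sizes from cost_per_unit and reuses the pool's threads instead of spawning new ones.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical stand-in for thread_pool->ParallelFor (not a TensorFlow API).
// Splits [0, total) into contiguous blocks and runs fn(start, end) for each
// block on its own thread, then waits for all of them to finish.
void ShardedFor(int64_t total, int num_shards,
                const std::function<void(int64_t, int64_t)>& fn) {
  assert(num_shards > 0);
  if (total <= 0) return;
  // Ceiling division so that every element lands in exactly one block.
  const int64_t block = (total + num_shards - 1) / num_shards;
  std::vector<std::thread> threads;
  for (int64_t start = 0; start < total; start += block) {
    const int64_t end = std::min(start + block, total);
    threads.emplace_back([&fn, start, end] { fn(start, end); });
  }
  for (auto& t : threads) t.join();
}

// Doubles every element, mirroring the loop body of TimeTwoParallelFunctor.
std::vector<float> TimeTwoSharded(const std::vector<float>& in,
                                  int num_shards) {
  std::vector<float> out(in.size());
  const float* in_ptr = in.data();
  float* out_ptr = out.data();
  ShardedFor(static_cast<int64_t>(in.size()), num_shards,
             [in_ptr, out_ptr](int64_t start, int64_t end) {
               // Each call owns the half-open range [start, end).
               for (int64_t i = start; i < end; ++i) {
                 out_ptr[i] = 2 * in_ptr[i];
               }
             });
  return out;
}
```

With 1000 elements and 4 shards, this sketch calls the lambda with [0, 250), [250, 500), [500, 750) and [750, 1000). The ranges never overlap, so the threads can write to out without locking.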

Step 3. Declare the OP.

The OP is declared in the same way as the regular time_two, except that we pass the context instead of the device.

template <typename Device, typename T>
class TimeTwoParallelOp : public OpKernel {
public:
  explicit TimeTwoParallelOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));

    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    TimeTwoParallelFunctor<Device, T>()(
      context,
      static_cast<int>(input_tensor.NumElements()),
      input_tensor.flat<T>().data(),
      output_tensor->flat<T>().data());
  }
};

Full example

In case you would like to build it yourself, here is the full content of time_two.cc:

#include "time_two.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/work_sharder.h"  // needed for thread_pool->ParallelFor


namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;


REGISTER_OP("TimeTwo")
    .Attr("T: numbertype")
    .Input("in: T")
    .Output("out: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });
REGISTER_OP("TimeTwoParallel")
    .Attr("T: numbertype")
    .Input("in: T")
    .Output("out: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

namespace functor {
template <typename T>
struct TimeTwoFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

template <typename T>
struct TimeTwoParallelFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, int size, const T* in, T* out) {
    auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;

    // ParallelFor(total, cost_per_unit, fn): the second argument is the
    // estimated cost of one element. size*1000 is a guess; I'm not quite sure
    // how to set it properly.
    thread_pool->ParallelFor(
      size, size*1000,
      [in, out](int64 start_index, int64 end_index) {
        // Each call handles the half-open range [start_index, end_index).
        for (int64 i = start_index; i < end_index; i++) {
          out[i] = 2 * in[i];
        }
    });
  }
};

// Implement kernel
template <typename Device, typename T>
class TimeTwoOp : public OpKernel {
public:
  explicit TimeTwoOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));

    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    TimeTwoFunctor<Device, T>()(
      context->eigen_device<Device>(),
      static_cast<int>(input_tensor.NumElements()),
      input_tensor.flat<T>().data(),
      output_tensor->flat<T>().data());
  }
};

template <typename Device, typename T>
class TimeTwoParallelOp : public OpKernel {
public:
  explicit TimeTwoParallelOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));

    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    TimeTwoParallelFunctor<Device, T>()(
      context,
      static_cast<int>(input_tensor.NumElements()),
      input_tensor.flat<T>().data(),
      output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TimeTwo").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      TimeTwoOp<CPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TimeTwoParallel").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      TimeTwoParallelOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);
} // end namespace functor
} // end namespace tensorflow

Step 4. Benchmark

Let us benchmark the two ops. Consider a simple benchmark script:

import time
import tensorflow as tf

time_two = tf.load_op_library("./time_two.so")


def main():
    SHAPE = (5, 224, 224, 3)
    rng = tf.random.Generator.from_seed(1234)

    # Default OP
    default_results = _benchmark(time_two.time_two, rng, SHAPE)
    print(default_results)

    # Parallel
    parallel_results = _benchmark(time_two.time_two_parallel, rng, SHAPE)
    print(parallel_results)

def _benchmark(func, rng, shape):
    for _ in range(10):
        func(rng.uniform(shape=shape))

    results = []
    for _ in range(100):
        noise = rng.uniform(shape)
        start = time.perf_counter()
        func(noise)
        stop = time.perf_counter()
        results.append(stop-start)

    return tf.reduce_mean(results)


if __name__ == "__main__":
    main()

The results for this particular OP are very similar:

# Single threaded
tf.Tensor(0.0017192168, shape=(), dtype=float32)
# Multi threaded
tf.Tensor(0.0015856741, shape=(), dtype=float32)

However, for the problem I was originally trying to solve, I got a 4x speedup in C++ by using ParallelFor.
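
For anyone who wants to time a loop directly in C++, outside TensorFlow, the warm-up-then-measure pattern of the Python script above can be sketched as follows. MeanSeconds is a hypothetical helper name of mine, and this is only a rough sketch: like the script, it reports a plain mean of wall-clock times and does nothing about outliers.

```cpp
#include <chrono>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: runs fn `warmup` times untimed, then `runs` times
// timed, and returns the mean wall-clock duration of the timed runs in
// seconds. Returns 0.0 when there are no timed runs.
double MeanSeconds(const std::function<void()>& fn, int warmup, int runs) {
  for (int i = 0; i < warmup; ++i) fn();

  std::vector<double> samples;
  for (int i = 0; i < runs; ++i) {
    const auto start = std::chrono::steady_clock::now();
    fn();
    const auto stop = std::chrono::steady_clock::now();
    samples.push_back(std::chrono::duration<double>(stop - start).count());
  }
  if (samples.empty()) return 0.0;
  return std::accumulate(samples.begin(), samples.end(), 0.0) /
         static_cast<double>(samples.size());
}
```

Calling MeanSeconds(work, 10, 100) mirrors the 10 warm-up and 100 timed iterations of the script. steady_clock is used because it is monotonic, so clock adjustments during the run cannot produce negative durations.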

Thanks again @jeongukjae for providing the example, which eventually helped me get this right.


If I try to compile this code with the TF custom op compilation method, there are problems with the values passed to the shard function.

Compile command:

TF_CFLAGS="-I/usr/local/cuda/targets/x86_64-linux/include -I/usr/lib64/python2.7/site-packages/tensorflow_core/include/ -I/usr/lib/python2.7/site-packages/tensorflow_core/include -D_GLIBCXX_USE_CXX11_ABI=0"
TF_LFLAGS="-L/usr/lib64/python2.7/site-packages/tensorflow_core -l:libtensorflow_framework.so.1"
g++ -std=c++11 -shared time_two.h time_two.cc -o pad_string.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O3

The two parameters of the shard function contain what looks like uninitialized data. For example, with a total of 2, a cost of 1,000,000, and 2 threads, the shard function should in theory receive [0, 1) and [1, 2). In fact, it receives two very large numbers.

Do you mean there is a problem with compiling time_two, or with your other custom op?
Can you share more details on your operation? What does it do?