How to Add Operators for new Library or Device - PaddlePaddle/Paddle GitHub Wiki

Step 1. Write the Operator and its kernels

Please refer to this document: new_op

The mechanism by which an Operator chooses a kernel is described in this document: switch_kernel

Step 2. Register the Operator and its kernels with the framework

After writing the components described in new_op, we register the Operator and its kernels with the framework.

When registering Operator kernels, we must take the type of each kernel into account.

We use OpKernelType to describe the type of an OperatorKernel.

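struct OpKernelType {
  platform::Place place_;       // device the kernel runs on, e.g. CPUPlace or CUDAPlace
  platform::Library library_;   // library that implements the kernel, e.g. CUDNN
  proto::DataType data_type_;   // element data type, e.g. float or double
  framework::Layout layout_;    // data layout of the tensors, e.g. NCHW or NHWC
};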

The meaning of each field of OpKernelType is described in this document: operator_kernel_type

Take conv2d as an example:

  • First, register the Operator together with its OpMaker and its gradient Operator:

REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
           ops::ConvOpGrad);
  • Second, register the Operator kernels with their library, place and data_type:

     REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace,
                        paddle::operators::CUDNNConvOpKernel<float>,
                        paddle::operators::CUDNNConvOpKernel<double>);
     REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
                        paddle::operators::CUDNNConvGradOpKernel<float>,
                        paddle::operators::CUDNNConvGradOpKernel<double>);
     
     REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace,
                        paddle::operators::CUDNNConvOpKernel<float>,
                        paddle::operators::CUDNNConvOpKernel<double>);
     REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
                        paddle::operators::CUDNNConvGradOpKernel<float>,
                        paddle::operators::CUDNNConvGradOpKernel<double>);

In the code above:

  • conv2d is the type (name) of the operator;
  • CUDNN is the library;
  • ::paddle::platform::CUDAPlace is the place;
  • the template parameter float/double of CUDNNConvOpKernel<T> is the data_type.