paddlelite
TODO:
- try to use arm82
- int8 quantization? fp16?
- OpenMP (default) ✅ vs thread pool ❌
- these log messages are quite interesting: how can SVE2 be put to use, or can the library detect it automatically?! (see the probe sketch after the log below)
log
[I 9/16 0:22: 8.872 ...ork/Paddle-Lite/lite/core/device_info.cc:1351 Setup] Total memory: 31322864KB
[I 9/16 0:22: 8.872 ...ork/Paddle-Lite/lite/core/device_info.cc:1352 Setup] SVE2 support: 0
[I 9/16 0:22: 8.872 ...ork/Paddle-Lite/lite/core/device_info.cc:1353 Setup] SVE2 f32mm support: 0
[I 9/16 0:22: 8.872 ...ork/Paddle-Lite/lite/core/device_info.cc:1354 Setup] SVE2 i8mm support: 0
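On the SVE2 question: on 64-bit ARM Linux the kernel advertises these features through the auxiliary vector, so the library can detect them automatically at startup. A minimal probe sketch (assumption: device_info.cc's Setup presumably checks something equivalent; the HWCAP2_* values below come from the kernel's arm64 asm/hwcap.h):
#include <cstdio>
#include <sys/auxv.h>  // getauxval, AT_HWCAP2
// Fallback definitions for older toolchain headers (values from arm64 asm/hwcap.h).
#ifndef HWCAP2_SVE2
#define HWCAP2_SVE2     (1UL << 1)
#endif
#ifndef HWCAP2_SVEI8MM
#define HWCAP2_SVEI8MM  (1UL << 9)
#endif
#ifndef HWCAP2_SVEF32MM
#define HWCAP2_SVEF32MM (1UL << 10)
#endif

int main() {
  unsigned long hwcap2 = getauxval(AT_HWCAP2);  // CPU feature bits reported by the kernel
  std::printf("SVE2 support: %d\n",       !!(hwcap2 & HWCAP2_SVE2));
  std::printf("SVE2 f32mm support: %d\n", !!(hwcap2 & HWCAP2_SVEF32MM));
  std::printf("SVE2 i8mm support: %d\n",  !!(hwcap2 & HWCAP2_SVEI8MM));
}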
git clone https://github.com/PaddlePaddle/Paddle-Lite.git #--depth=1
# e241420f813bd91f5164f0d9ee0bc44166c0a172
Customized changes to the build scripts:
diff --git a/lite/tools/build_android.sh b/lite/tools/build_android.sh
index f3dceb9..cf8efac 100755
--- a/lite/tools/build_android.sh
+++ b/lite/tools/build_android.sh
@@ -76,7 +76,7 @@ WITH_CONVERT_TO_SSA=ON
# use Arm DNN library instead of built-in math library, defaults to OFF.
WITH_ARM_DNN_LIBRARY=OFF
# num of threads used during compiling..
-readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
+readonly NUM_PROC=32
@@ -206,9 +206,6 @@ function make_tiny_publish_so {
# Step1. Create directory for compiling.
build_dir=$workspace/build.lite.android.$ARCH.$TOOLCHAIN
- if [ -d $build_dir ]; then
- rm -rf $build_dir
- fi
mkdir -p $build_dir
cd $build_dir
diff --git a/lite/tools/build_linux.sh b/lite/tools/build_linux.sh
index ace7a8b..d3df143 100755
--- a/lite/tools/build_linux.sh
+++ b/lite/tools/build_linux.sh
@@ -100,7 +100,7 @@ WITH_BENCHMARK=OFF
# use Arm DNN library instead of built-in math library, defaults to OFF.
WITH_ARM_DNN_LIBRARY=OFF
# num of threads used during compiling..
-readonly NUM_PROC=${LITE_BUILD_THREADS:-4}
+readonly NUM_PROC=32
@@ -344,9 +344,6 @@ function make_publish_so {
build_dir=${build_dir}.kunlunxin_xpu
fi
- if [ -d $build_dir ]; then
- rm -rf $build_dir
- fi
mkdir -p $build_dir
cd $build_dir
In order to build with clang successfully, the following change is needed:
diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h
index c5757b8..2c87f41 100644
--- a/lite/api/paddle_place.h
+++ b/lite/api/paddle_place.h
@@ -15,6 +15,7 @@
#pragma once
#include <set>
#include <string>
+#include <stdint.h>
## use conda gcc
# export CC=$HOME/miniforge3/bin/gcc
#export CXX=$HOME/miniforge3/bin/g++
#./lite/tools/build_linux.sh --arch=armv8 --with_extra=ON --toolchain=gcc
# use system clang, which gives much better inference performance than gcc
export CXX=/usr/bin/clang++-16
export CC=/usr/bin/clang-16
./lite/tools/build_linux.sh --arch=armv8 --with_extra=ON --toolchain=clang
Compile and execute the demo:
export PDLITE_INC=$HOME/work/Paddle-Lite/build.lite.linux.armv8.gcc/inference_lite_lib.armlinux.armv8/cxx/include
export PDLITE_LIB=$HOME/work/Paddle-Lite/build.lite.linux.armv8.gcc/inference_lite_lib.armlinux.armv8/cxx/lib
g++ -O3 -o pdlite_perf pdlite_perf.cpp utils.cpp -std=c++17 -I$PDLITE_INC -L$PDLITE_LIB -lpaddle_light_api_shared `pkg-config --cflags --libs opencv4`
LD_LIBRARY_PATH=$PDLITE_LIB ./pdlite_perf --only-test=efficientformerv2_s0 #2>/dev/null
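For reference, a minimal sketch of the Paddle Lite light-API flow that pdlite_perf.cpp presumably wraps (the model path, thread count and zero-filled input are placeholders; the real benchmark feeds an OpenCV-preprocessed image and adds timing):
#include <algorithm>
#include <cstdio>
#include "paddle_api.h"  // light API, from $PDLITE_INC

using namespace paddle::lite_api;

int main() {
  MobileConfig config;
  config.set_model_from_file(".pdlite/efficientformerv2_s0.nb");  // placeholder path to an opt-converted model
  config.set_threads(4);                   // CPU worker threads (OpenMP by default)
  config.set_power_mode(LITE_POWER_HIGH);  // prefer the big cores

  auto predictor = CreatePaddlePredictor<MobileConfig>(config);

  auto input = predictor->GetInput(0);
  input->Resize({1, 3, 224, 224});
  float* in = input->mutable_data<float>();
  std::fill(in, in + 3 * 224 * 224, 0.f);  // placeholder: the real code writes a normalized image here

  predictor->Run();

  auto output = predictor->GetOutput(0);
  const float* scores = output->data<float>();
  int64_t nclasses = output->shape()[1];
  int top1 = std::max_element(scores, scores + nclasses) - scores;
  std::printf("top-1 class %d, score %f\n", top1, scores[top1]);
}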
- https://www.paddlepaddle.org.cn/lite/v2.12/demo_guides/opencl.html
- Deployment must handle devices without OpenCL support: check in advance with the API bool ::IsOpenCLBackendValid(), and load a CPU model when OpenCL is unavailable (see the sketch after this list).
- If performance is unsatisfactory, consider the tuning API config.set_opencl_tune(CL_TUNE_NORMAL); the first run then incurs some extra initialization time.
- For accuracy-sensitive scenarios, consider forcing FP32 precision via the API config.set_opencl_precision(CL_PRECISION_FP32).
- To mitigate slow first-time loading, consider the API config.set_opencl_binary_path_name(bin_path, bin_name) to speed up the first inference.
- The Paddle Lite OpenCL backend does not yet fully support dynamic shapes, so running a dynamic-shape model may fail.
- Deploying on the OpenCL backend does not guarantee faster inference than the CPU. GPUs suit workloads with high computational intensity; if a model's operators have low compute density, GPU inference can fall behind the CPU. When designing a model structure for the GPU, minimize the number of low-compute-density operators such as slice and concat; see the "optimization suggestions" chapter of the guide on getting the best performance from the GPU.
- Mixed OpenCL memory objects for inference: most OpenCL operators support the cl::Image2D layout and a few support cl::Buffer (coverage is being extended), for the following reasons:
  - The relative performance of cl::Image2D and cl::Buffer differs across devices.
  - Devices cap cl::Image2D at CL_DEVICE_IMAGE2D_MAX_HEIGHT and CL_DEVICE_IMAGE2D_MAX_WIDTH, so ops with oversized shapes fail with: malloc image is out of max image size(w,h).
  - Some ops get much better performance from cl::Buffer, e.g. reshape, transpose, argmax with keep_dims=false, reduce, etc. Both memory objects are configurable: point the environment variable OPENCL_MEMORY_CONFIG_FILE at an "OpenCL memory object configuration file" to manually force selected ops onto the cl::Buffer implementation.
- Heterogeneous OpenCL + CPU inference: for operators that neither cl::Image2D nor cl::Buffer supports, or supports only poorly, selected ops can be pinned to the CPU implementation, again via the OPENCL_MEMORY_CONFIG_FILE environment variable. The guide's example runs the benchmark tool on a PaddlePaddle deployment-format model (ch_PP-OCRv3_rec_infer) with conv2d, depthwise_conv2d and pool2d pinned to the CPU implementation, while the remaining ops use the OpenCL backend defaults (mostly cl::Image2D).
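A minimal C++ sketch of the CPU-fallback pattern described above, using the config APIs named in these notes (the model file names and kernel-cache path are placeholders):
#include <memory>
#include "paddle_api.h"  // from the inference_lite_lib cxx/include directory

using namespace paddle::lite_api;

std::shared_ptr<PaddlePredictor> build_predictor() {
  MobileConfig config;
  if (IsOpenCLBackendValid()) {
    config.set_model_from_file("model_gpu.nb");      // placeholder: model optimized with valid_targets=opencl,arm
    config.set_opencl_precision(CL_PRECISION_FP16);  // CL_PRECISION_FP32 when accuracy matters
    config.set_opencl_tune(CL_TUNE_NORMAL);          // auto-tune; adds one-time startup cost
    config.set_opencl_binary_path_name("/tmp", "kernels.bin");  // cache compiled kernels for faster reload
  } else {
    config.set_model_from_file("model_cpu.nb");      // placeholder: plain ARM CPU model
  }
  return CreatePaddlePredictor<MobileConfig>(config);
}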
Build patch/commands for the Android OpenCL library:
diff --git a/cmake/os/android.cmake b/cmake/os/android.cmake
index 79fdbd8..7de771f 100644
--- a/cmake/os/android.cmake
+++ b/cmake/os/android.cmake
@@ -16,6 +16,7 @@ set(ANDROID TRUE)
set(ANDROID_ARCH_ABI_LIST "arm64-v8a" "armeabi-v7a" "armeabi-v6" "armeabi" "mips" "mips64" "x86" "x86_64")
set(ANDROID_STL_TYPE_LIST "c++_static" "gnustl_static" "c++_shared")
+set(ANDROID_NDK xxx/android-ndk-r22b)
# Android ndk
if(NOT DEFINED ANDROID_NDK)
set(ANDROID_NDK $ENV{NDK_ROOT})
export ANDROID_NDK=xxx/android-ndk-r22b
./lite/tools/build_android.sh --with_extra=ON --with_opencl=ON --toolchain=clang
Use an x86-64 machine to convert the models, in order to avoid building paddlepaddle...
- x2paddle: pip install x2paddle
- paddlepaddle: pip install paddlepaddle
- paddlelite (attention: its version must exactly match the runtime library, otherwise the edgenext model fails to run!)
How to build paddlelite for x86 Python:
cd third-party/protobuf-host
diff --git a/src/google/protobuf/compiler/java/java_file.cc b/src/google/protobuf/compiler/java/java_file.cc
index 3cbc530eb..da86762e0 100644
--- a/src/google/protobuf/compiler/java/java_file.cc
+++ b/src/google/protobuf/compiler/java/java_file.cc
@@ -65,7 +65,7 @@ namespace java {
namespace {
struct FieldDescriptorCompare {
- bool operator ()(const FieldDescriptor* f1, const FieldDescriptor* f2) {
+ bool operator ()(const FieldDescriptor* f1, const FieldDescriptor* f2) const{
if(f1 == NULL) {
return false;
}
# here we use conda python 3.8
conda install patchelf
git clone --depth=1 https://github.com/PaddlePaddle/Paddle-Lite.git
# conda gcc/g++ 12.3.0 doesn't work!
###export CC=/usr/bin/gcc
###export CXX=/usr/bin/g++
export CXXFLAGS="-Wno-error=array-bounds -Wno-error=pessimizing-move"
# $HOME/miniforge3/envs/py3.8/lib/python3.8/site-packages/setuptools/_vendor/packaging/version.py:196
# version = '0.0+e241420'
cd Paddle-Lite
./lite/tools/build_linux.sh --arch=x86 --with_python=ON --with_extra=ON
# cd build*
# make publish_inference
find . | grep "\.whl"
pip install xxxxx.whl
./lite/tools/build.sh build_optimize_tool
pip install six requests onnx==1.14
- efficientformerv2 ✅
- SwiftFormer ❌ runtime error! (with --opset-version >= 9)
  [F 9/ 4 0: 3:25. 51 ...e-Lite/lite/kernels/host/cast_compute.cc:164 Run] other has not been implemented transform with dtype3 X, dtype0 Out
- EMO ✅ see convert-tools/emo.patch (modified source code):
diff --git a/python/sota/emo.py b/python/sota/emo.py
index ebe3c9e..562a4d5 100644
--- a/python/sota/emo.py
+++ b/python/sota/emo.py
@@ -175,16 +175,7 @@ class iRMB(nn.Module):
x = self.norm(x)
B, C, H, W = x.shape
if self.attn_s:
- # padding
- if self.window_size <= 0:
- window_size_W, window_size_H = W, H
- else:
- window_size_W, window_size_H = self.window_size, self.window_size
- pad_l, pad_t = 0, 0
- pad_r = (window_size_W - W % window_size_W) % window_size_W
- pad_b = (window_size_H - H % window_size_H) % window_size_H
- x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0,))
- n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
+ n1, n2 = H // self.window_size, W // self.window_size
x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous()
# attention
b, c, h, w = x.shape
@@ -206,8 +197,6 @@ class iRMB(nn.Module):
x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head, h=h, w=w).contiguous()
# unpadding
x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous()
- if pad_r > 0 or pad_b > 0:
- x = x[:, :, :H, :W].contiguous()
else:
x = self.v(x)
- edgenext ✅
  - opencl ❌ runtime error! segmentation fault
- mobilevitv2 ❌ conversion error!
[F 9/ 3 1:42:33.691 ...rk/work/Paddle-Lite/lite/core/op_lite.cc:176 AttachOutput] Check failed: is_dispensable || is_have_output:
FatalError: `Process abort signal` is detected by the operating system.
[TimeInfo: *** Aborted at 1693676553 (unix time) try "date -d @1693676553" if you are using GNU date ***]
[SignalInfo: *** SIGABRT (@0x3e800001a20) received by PID 6688 (TID 0x7fd501eaf740) from PID 6688 ***]
Trying --opset-version == 8 suggests the failure is caused by nn.GroupNorm not being supported 🤷
- mobilevit
  - --opset-version == 9 ✅
  - --opset-version > 9 ❌ runtime error!
[F 9/ 4 0:23:55.496 ...p/Paddle-Lite/lite/operators/slice_op.cc:47 InferShapeImpl] Check failed: (param_.axes[i] < in_dims.size()): -1!<3 The index of dimension in axes must be less than the size of input shape.
- LeViT ✅
  - use the fuse option
  - by adding HardSwish support in x2paddle:
#/home/albert/miniforge3/envs/py3.8/lib/python3.8/site-packages/x2paddle/op_mapper/onnx2paddle/opset_legacy.py:179
'HardSwish': ['paddle.nn.Hardswish'],
import torch
from x2paddle.convert import pytorch2paddle
pytorch2paddle(module=torch.jit.load('.pt/efficientformerv2_s0.pt'),
save_dir="./pd_model",
jit_type="trace",
input_examples=[torch.randn(1,3,224,224)],
enable_code_optim=False,
convert_to_lite=True,
lite_valid_places="arm",
lite_model_type="naive_buffer")
- efficientformerv2 ✅
- SwiftFormer ❌ convert error!
========= 1 OPs are not supported yet ===========
========== aten::linalg_vector_norm ============
- edgenext ❌ convert error!
========= 1 OPs are not supported yet ===========
========== aten::linalg_vector_norm ============
- mobilevitv2 ❌ convert error!
Error: This model is not supported, because 1 ops are not supported on 'arm'. These unsupported ops are: 'expand_as_v2'.
--------------------------------------
C++ Traceback (most recent call last):
--------------------------------------
0 paddle::lite_api::OptBase::Run()
1 paddle::lite_api::OptBase::CheckIfModelSupported(bool)
----------------------
Error Message Summary:
----------------------
FatalError: `Process abort signal` is detected by the operating system.
[TimeInfo: *** Aborted at 1694800212 (unix time) try "date -d @1694800212" if you are using GNU date ***]
[SignalInfo: *** SIGABRT (@0x3e80000504d) received by PID 20557 (TID 0x7fc797012740) from PID 20557 ***]
[1] 20557 abort (core dumped) ipython
- mobilevit ❌ convert error!
[Hint: Expected capacity == in_size, but received capacity:65536 != in_size:50176.] (at /paddle/paddle/fluid/operators/reshape_op.cc:234)
[operator < reshape2 > error]
- LeViT ❌ convert error!
========= 1 OPs are not supported yet ===========
========== aten::reshape_as ============
- tf_efficientnetv2 ❌ convert error!
========= 2 OPs are not supported yet ===========
========== aten::ceil ============
========== aten::pad ============
Dynamic offline quantization: the weights of specific OPs in the model are quantized from FP32 to INT8/16. Such a quantized model can be executed in two ways:
- The first is dequantized prediction: the INT8/16 weights are first dequantized back to FP32, and prediction then runs with ordinary FP32 floating-point arithmetic;
- The second is quantized prediction: the quantization parameters of each quantized OP's inputs are computed dynamically during prediction, and INT8 integer arithmetic is performed on the quantized inputs and weights.
Note: Paddle Lite currently only supports the first (dequantized) mode. (Per the Paddle Lite docs, the opt tool produces such a model when passed --quant_model=true together with --quant_type=QUANT_INT8 or QUANT_INT16.)
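A toy illustration of the dequantize step in that first mode (the function and the per-channel scale layout here are illustrative, not Paddle Lite internals):
#include <cstdint>
#include <vector>

// Expand INT8 weights back to FP32 using one scale per output channel;
// inference then proceeds with ordinary FP32 kernels.
std::vector<float> dequantize(const std::vector<int8_t>& w,
                              const std::vector<float>& scale,
                              size_t channel_size) {
  std::vector<float> out(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    out[i] = static_cast<float>(w[i]) * scale[i / channel_size];  // w_fp32 = w_int8 * scale
  }
  return out;
}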
use gcc-10 to build opt
./lite/tools/build.sh build_optimize_tool
CPU/GPU FP32
Model | Top-1 | Top-1 //20 est. | Top-1 //50 est. | #params | GMACs |
---|---|---|---|---|---|
efficientformerv2_s0 | - | 75.9 | 75.9 | 3.5M | 0.40G |
efficientformerv2_s1 | - | 78.8 | 80.2 | 6.1M | 0.65G |
efficientformerv2_s2 | - | 82.1 | 81.8 | 12.6M | 1.25G |
edgenext_xx_small | - | 70.8 | 70.7 | 1.3M | 0.26G |
edgenext_x_small | - | 74.8 | 74.8 | 2.3M | 0.54G |
edgenext_small/usi | - | 80.6 | 79.9 | 5.6M | 1.26G |
mobilevit_xx_small | - | 68.9 | 66.5 | 1.3M | 0.36G |
mobilevit_x_small | - | 74.0 | 73.7 | 2.3M | 0.89G |
mobilevit_small | - | 77.6 | 77.9 | 5.6M | 2.0 G |
LeViT_128S | - | 75.9 | 76.1 | 7.8M | 0.30G |
LeViT_128 | - | 79.4 | 78.1 | 9.2M | 0.41G |
LeViT_192 | - | 79.6 | 79.6 | 11 M | 0.66G |
LeViT_256 | - | 81.1 | 81.4 | 19 M | 1.12G |
resnet50 | - | 79.6 | 81.3 | 25.6M | 4.1G |
mobilenetv3_large_100 | - | 75.6 | 75.3 | 5.5M | 0.29G |
X2Paddle currently supports 90+ TensorFlow OPs, 30+ Caffe OPs, 90+ ONNX OPs, and 130+ PyTorch OPs, covering most operations commonly used in CV classification models. The X2Paddle documentation gives the full list of supported OPs.