mnn - YingkunZhou/EdgeTransformerBench GitHub Wiki

TODO:

  • Find the faulty operators and fall them back to the CPU for execution
  • int precision = 2;
  • Some surprising findings:
    • An mnn model is tied to the MNN version that converted it: there is no forward compatibility, only backward compatibility, i.e. a model converted by an old version runs on a newer inference library, but a model converted by a new version does not run on an older inference library!

support matrix

| System | CPU | Vulkan | OpenCL | OpenGL |
| --- | --- | --- | --- | --- |
| Raspberry Pi 4B | | mesa | ❌ don't support! | - |
| Edge2 Linux | | ❌ don't support! | ✅ 77+603 | |
| Orin Linux | | ✅ (128+)broadcast | ❌ don't support! | |
| Vim3 Linux | | panfrost | ❌ GPU driver bug! | |
| Vim3 Android | | ❌ GPU driver bug! | ❌ GPU driver bug! | |
| Edge2 Android | | ✅ (128+)broadcast | ✅ 77+603 | |
| Samsung S20+ (Exynos 990 (7 nm+)) | | ✅ (128+)broadcast | ✅ 77+603 | |
| Xiaomi MIX 2S (SDM845 (10 nm)) | | ✅ (128+)broadcast | ✅ 603? | |
Currently, to get MNN's Vulkan and OpenCL backends running on all of these devices (yes, they still have bugs), the main patches are listed below; the op-type numbers in the table above are the operators that have to be bypassed to the CPU. Different test devices and different network models need different adjustments:
diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp
index d8b3775..c9d4617 100644
--- a/source/backend/opencl/core/OpenCLBackend.cpp
+++ b/source/backend/opencl/core/OpenCLBackend.cpp
@@ -428,7 +428,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
     auto creators = gCreator();
     auto iter      = creators->find(std::make_pair(op->type(), mOpenCLRuntime->getGpuMemType()));

-    if (iter == creators->end()) {
+    if (iter == creators->end() || op->type() == 77) {
         mOpenCLRuntime->setDevideOpRecord();
         #if 0//close log
         if (nullptr != op->name()) {
diff --git a/source/backend/vulkan/image/execution/VulkanBinary.cpp b/source/backend/vulkan/image/execution/VulkanBinary.cpp
index b3f9eba..655b067 100644
--- a/source/backend/vulkan/image/execution/VulkanBinary.cpp
+++ b/source/backend/vulkan/image/execution/VulkanBinary.cpp
@@ -170,6 +170,10 @@ public:
     virtual VulkanBasicExecution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
                                 Backend* backend) const override {
         auto input0 = inputs[0];
+        for (int i = 1; i < inputs.size(); ++i) {
+            auto input = inputs[i];
+            if (input0->dimensions()  + input->dimensions() == 7) return nullptr;
+        }
         auto image = TensorUtils::getDescribe(input0)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4;
         auto shader = _getShaderName(op, image);
         if (shader.empty()) {
how to build
git clone https://github.com/alibaba/MNN.git #--depth=1
cd MNN
# git checkout d8266f9697650d4a90cccea337c2ae3ee070c373
# the latest commit id is c442ff39ec9a6a99c28bddd465d8074a7b5c1cca --> tag: 2.6.3
# optional: only needed after modifying the MNN schema
./schema/generate.sh
mkdir -p build && cd build
# Edge2 Linux
cmake -D CMAKE_BUILD_TYPE=Release -D MNN_VULKAN=ON -D MNN_OPENCL=ON .. \
-D CMAKE_INSTALL_PREFIX=../install -D MNN_SEP_BUILD=OFF #-D MNN_OPENGL=ON
make install -j`nproc`
--------------------------------------------------------------------------------------------------------------
# android
## way 1: native build
pkg install mesa-dev # for opengl
cmake -D CMAKE_BUILD_TYPE=Release -D MNN_USE_LOGCAT=false -D MNN_VULKAN=ON -D MNN_OPENCL=ON .. \
-D CMAKE_INSTALL_PREFIX=../install -DMNN_BUILD_FOR_ANDROID_COMMAND=true -DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=.  -D MNN_SEP_BUILD=OFF #-D MNN_OPENGL=ON
make install -j`nproc`
## way 2: cross build
cd project/android
vim build_64.sh
#######################################################
#!/bin/bash
cmake ../../../ \
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
-DCMAKE_BUILD_TYPE=Release \
-DANDROID_ABI="arm64-v8a" \
-DMNN_USE_LOGCAT=false \
-DANDROID_PLATFORM=android-24  \
-DMNN_BUILD_FOR_ANDROID_COMMAND=true \
-D MNN_OPENCL=ON -D MNN_VULKAN=ON -D MNN_OPENGL=ON \
-D MNN_SEP_BUILD=OFF -D CMAKE_INSTALL_PREFIX=../install-mnn \
-DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=.

make install -j32
#######################################################
export ANDROID_NDK=<prefix>/android-ndk-r22b
mkdir build_64 && cd build_64 && ../build_64.sh
**Note: on Jetson Orin, the latest MNN needs clang to compile successfully**
export CC=clang
export CXX=clang++
## set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
model conversion
-D MNN_BUILD_CONVERTER=ON

./MNNConvert -f ONNX --modelFile ../onnx/debug.onnx --MNNModel debug.mnn --bizCode MNN

Linux

Vulkan

sudo apt install libvulkan-dev # to install /usr/lib/aarch64-linux-gnu/libvulkan.so which mnn needed!

OpenGL (Edge2)

the patch to build
diff --git a/source/backend/opengl/GLHead.hpp b/source/backend/opengl/GLHead.hpp
index 96be807..348f9cf 100644
--- a/source/backend/opengl/GLHead.hpp
+++ b/source/backend/opengl/GLHead.hpp
@@ -17,7 +17,7 @@
 #define CONTEXT_FREE_API
 #include <assert.h>
 #include <stdlib.h>
-#ifdef __ANDROID__
+#ifndef __ANDROID__
 #include <GLES2/gl2.h>
 #include <GLES2/gl2ext.h>
 #include <GLES3/gl31.h>
diff --git a/source/shape/ShapePool3D.cpp b/source/shape/ShapePool3D.cpp
index a346b79..406795a 100644
--- a/source/shape/ShapePool3D.cpp
+++ b/source/shape/ShapePool3D.cpp
@@ -77,6 +77,7 @@ public:
         auto size  = (float)outputs[0]->elementSize() / 1024.0f / 1024.0f;
         auto layer = op->main_as_Pool3D();
         float flopsPerElement = 1;
+        return size * flopsPerElement;
         for (auto kernel: *layer->kernels()) {
             flopsPerElement *= kernel;
         }
attempt to build
mkdir -p build && cd build
git clone https://gitlab.freedesktop.org/panfrost/mesa.git --depth=1
sudo apt install libx11-dev
export CPLUS_INCLUDE_PATH=$PWD/mesa/include
export LIBRARY_PATH=$HOME/work/opengl
$ ll ~/work/opengl/
total 84M
lrwxrwxrwx 1 albert albert  11 Aug 22 12:06 libEGL.so -> libEGL.so.1
lrwxrwxrwx 1 albert albert  15 Aug 22 11:43 libEGL.so.1 -> libEGL.so.1.1.0
lrwxrwxrwx 1 albert albert  42 Aug 22 19:36 libEGL.so.1.1.0 -> /usr/lib/aarch64-linux-gnu/libEGL.so.1.1.0
lrwxrwxrwx 1 albert albert  14 Aug 22 12:06 libGLESv3.so -> libGLESv3.so.3
lrwxrwxrwx 1 albert albert  18 Aug 22 11:42 libGLESv3.so.3 -> libGLESv3.so.3.1.0
lrwxrwxrwx 1 albert albert  12 Aug 22 19:36 libGLESv3.so.3.1.0 -> libmali.so.1
lrwxrwxrwx 1 albert albert  12 Aug 22 11:40 libmali.so -> libmali.so.1
lrwxrwxrwx 1 albert albert  37 Aug 24 11:06 libmali.so.1 -> libmali-valhall-g610-g15p0-wayland.so
-rw-r--r-- 1 albert albert 42M Aug 24 00:25 libmali-valhall-g610-g13p0-x11-wayland-gbm.so
-rw-r--r-- 1 albert albert 42M Aug 22 11:38 libmali-valhall-g610-g15p0-wayland.so
cmake -D CMAKE_BUILD_TYPE=Release -D MNN_OPENGL=ON .. -D CMAKE_INSTALL_PREFIX=../install-opengl -D MNN_SEP_BUILD=OFF
make install -j10
compile and execute demo
#export OPENCV_LIB=$HOME/miniforge3/envs/py3.8/lib
#export OPENCV_INC=$HOME/miniforge3/envs/py3.8/include/opencv4
#export MNN_LIB=$HOME/work/MNN/install-opengl/lib
#export MNN_INC=$HOME/work/MNN/install-opengl/include

#g++ -O3 -o mnn_perf mnn_perf.cpp utils.cpp  -std=c++17 \
#    -I$MNN_INC -I$OPENCV_INC -L$MNN_LIB -L$OPENCV_LIB \
#    -lMNN_GL -lMNN -lopencv_imgproc -lopencv_imgcodecs -lopencv_core -lopencv_dnn

# for linux
sudo apt install libopencv-dev
# for android
pkg install opencv

g++ -O3 -o mnn_perf mnn_perf.cpp utils.cpp  -std=c++17 -I$MNN_INC -L$MNN_LIB -lMNN `pkg-config --cflags --libs opencv4`
g++ -O3 -DTEST -o mnn_perf-test mnn_perf.cpp utils.cpp  -std=c++17 -I$MNN_INC -L$MNN_LIB -lMNN `pkg-config --cflags --libs opencv4`
g++ -g -DDEBUG -o mnn_perf-debug mnn_perf.cpp utils.cpp  -std=c++17 -I$MNN_INC -L$MNN_LIB -lMNN `pkg-config --cflags --libs opencv4`

# export LD_LIBRARY_PATH=$OPENCV_LIB:$MNN_LIB
export LD_LIBRARY_PATH=$MNN_LIB
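
For orientation, here is a minimal sketch of what a single inference with a selectable backend looks like (this is an assumption-laden illustration using the standard MNN C++ session API, not the repository's actual mnn_perf.cpp); the forward type plays the role of the --backend flag used below:

```cpp
#include <MNN/Interpreter.hpp>
#include <MNN/Tensor.hpp>
#include <cstdio>
#include <memory>

int main() {
    // Load a model converted by MNNConvert.
    std::shared_ptr<MNN::Interpreter> net(
        MNN::Interpreter::createFromFile("resnet50.mnn"));

    MNN::ScheduleConfig config;
    // Backend selection: MNN_FORWARD_CPU / MNN_FORWARD_VULKAN /
    // MNN_FORWARD_OPENCL / MNN_FORWARD_OPENGL
    config.type      = MNN_FORWARD_VULKAN;
    config.numThread = 4;

    auto session = net->createSession(config);
    auto input   = net->getSessionInput(session, nullptr);

    // Stage the preprocessed 1x3x224x224 image through a host-side NCHW tensor.
    MNN::Tensor hostInput(input, MNN::Tensor::CAFFE);
    // ... fill hostInput.host<float>() ...
    input->copyFromHostTensor(&hostInput);

    net->runSession(session);

    auto output = net->getSessionOutput(session, nullptr);
    MNN::Tensor hostOutput(output, MNN::Tensor::CAFFE);
    output->copyToHostTensor(&hostOutput);
    printf("class 0 score: %f\n", hostOutput.host<float>()[0]);
    return 0;
}
```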
**panfrost opengl** seems unable to run!!
$ ./mnn_perf --only-test res --backend g
Creating MNN Interpreter: resnet50
The device support i8sdot:1, support fp16:1, support i8mm: 0
eglChooseConfig error !!!
mContext error !!!
terminate called after throwing an instance of 'std::logic_error'
  what():  basic_string: construction from null is not valid
[1]    174026 abort      ./mnn_perf --only-test res --backend g

OpenGL debug journey (edge2)

first, referring to glmark2-es-wayland,

we arrive at the following patch:
diff --git a/source/backend/opengl/GLContext.cpp b/source/backend/opengl/GLContext.cpp
index 53be996..b09f227 100644
--- a/source/backend/opengl/GLContext.cpp
+++ b/source/backend/opengl/GLContext.cpp
@@ -7,11 +7,24 @@
 //

 #include "backend/opengl/GLContext.hpp"
+#include <wayland-client.h>
+#include <EGL/eglext.h>
 namespace MNN {
 namespace OpenGL {
     GLContext::GLContext() {
         if(!(eglGetCurrentContext() != EGL_NO_CONTEXT)){
+#if 1
+            wl_display *display = wl_display_connect(NULL);
+            assert(display);
+            wl_display_roundtrip(display);
+            mDisplay = eglGetPlatformDisplay(
+                EGL_PLATFORM_WAYLAND_KHR,
+                reinterpret_cast<void*>(display),
+                nullptr
+            );
+#else
             mDisplay = eglGetDisplay(EGL_DEFAULT_DISPLAY);
+#endif
             if (mDisplay == EGL_NO_DISPLAY) {
                 MNN_PRINT("eglGetDisplay error !!! \n");
                 mIsCreateError = true;
@@ -20,8 +33,12 @@ namespace OpenGL {
             int minorVersion;
             eglInitialize(mDisplay, &majorVersion, &minorVersion);
             EGLint numConfigs;
+#if 1
+            static const EGLint configAttribs[] = {
+#else
             static const EGLint configAttribs[] = {EGL_SURFACE_TYPE,
                                                 EGL_PBUFFER_BIT,
+#endif
                                                 EGL_RENDERABLE_TYPE,
                                                 EGL_OPENGL_ES2_BIT,
                                                 EGL_RED_SIZE,
The CMake config files also need to be modified:
# CMakeLists.txt
list(APPEND MNN_EXTRA_DEPENDS wayland-client)
# source/backend/opengl/CMakeLists.txt
target_link_libraries(MNN_GL MNN GLESv3 wayland-client EGL)
With the Wayland display in use, it then works:
$ ./mnn_perf-test --only-test=res --backend=g
Creating MNN Interpreter: resnet50
The device support i8sdot:1, support fp16:1, support i8mm: 0
gpu type : Mali-G610 (Panfrost)
gl version : OpenGL ES 3.1 Mesa 23.0.0-devel
gpu type : Mali-G610 (Panfrost)
gl version : OpenGL ES 3.1 Mesa 23.0.0-devel
Don't support type 112, /global_pool/pool/GlobalAveragePool_output_0
Don't support type [Pooling3D], /global_pool/pool/GlobalAveragePool_output_0
Create exection error : 112
Can't run session because not resized
(index: 999,  score: 0.000000), (index: 998,  score: 0.000000), (index: 997,  score: 0.000000),

This "Don't support type [Pooling3D]" is odd, since OpenCL clearly does support it; what difference is there between the OpenGL and OpenCL backends?

Enable the debug prints:
diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp
index 20942e0..7d285e1 100644
--- a/source/backend/opencl/core/OpenCLBackend.cpp
+++ b/source/backend/opencl/core/OpenCLBackend.cpp
@@ -405,9 +405,10 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
 #endif
     auto creators = gCreator();
     auto iter      = creators->find(std::make_pair(op->type(), mOpenCLRuntime->getGpuMemType()));
+    MNN_PRINT("Start OpenCLBackend::onCreate %d\n", op->type());

     if (iter == creators->end()) {
-        #if 0//close log
+        #if 1//close log
         if (nullptr != op->name()) {
             MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str());
         } else {
diff --git a/source/backend/opengl/GLBackend.cpp b/source/backend/opengl/GLBackend.cpp
index b091237..c421c70 100644
--- a/source/backend/opengl/GLBackend.cpp
+++ b/source/backend/opengl/GLBackend.cpp
@@ -290,6 +290,7 @@ void GLBackend::upload(GLuint textureId, const float *inputData, int width, int

 Execution *GLBackend::onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op) {
+    MNN_PRINT("Start OpenCLBackend::onCreate %d\n", op->type());
     auto map  = gCreator();
     auto iter = map->find(op->type());
     if (iter == map->end()) {
@@ -469,6 +470,7 @@ int GLRuntime::onGetRuntimeStatus(RuntimeStatus statusEnum) const {
 }

 Runtime::CompilerType GLRuntime::onGetCompilerType() const {
+    return Compiler_Loop;
     return Compiler_Origin;
 }

We can see that where the op should have been 112, OpenCL sees 128 instead; tracing back, this is due to the CompilerType: OpenCL uses Compiler_Loop, while OpenGL uses Compiler_Origin.

After swapping it over, problems still show up:
Start OpenCLBackend::onCreate 128
Don't support type 128
Start OpenCLBackend::onCreate 12
Start OpenCLBackend::onCreate 128
Don't support type 128

https://github.com/alibaba/MNN/blob/master/schema/current/MNN_generated.h

  OpType_Pooling3D = 112,
  OpType_Convolution3D = 113,
  OpType_MatrixBandPart = 114,
  OpType_GatherND = 115,
  OpType_DetectionPostProcess = 116,
  OpType_UnravelIndex = 117,
  OpType_ScatterNd = 118,
  OpType_OneHot = 119,
  OpType_BroadcastTo = 120,
  OpType_Dilation2D = 121,
  OpType_Interp3D = 122,
  OpType_Raster = 128,

Then, after grepping the file names:

$ find source | grep Raster
source/backend/cpu/CPURaster.hpp
source/backend/cpu/CPURaster.cpp
source/backend/vulkan/buffer/execution/VulkanRaster.cpp
source/backend/vulkan/buffer/execution/VulkanRaster.hpp
source/backend/vulkan/image/execution/VulkanRaster.cpp
source/backend/vulkan/image/execution/VulkanRaster.hpp
source/backend/cuda/execution/Raster.cuh
source/backend/cuda/execution/RasterExecution.cpp
source/backend/cuda/execution/RasterExecution.hpp
source/backend/cuda/execution/Raster.cu
source/backend/metal/MetalRaster.mm
source/backend/metal/MetalRaster.hpp
source/backend/opencl/execution/buffer/RasterBufExecution.cpp
source/backend/opencl/execution/buffer/RasterBufExecution.hpp
source/backend/opencl/execution/image/RasterExecution.cpp
source/backend/opencl/execution/image/RasterExecution.hpp
source/backend/nnapi/execution/NNAPIRaster.cpp
source/backend/nnapi/execution/NNAPIRaster.hpp
source/backend/coreml/execution/CoreMLRaster.cpp
source/backend/coreml/execution/CoreMLRaster.hpp
source/backend/coreml/backend/CoreMLRaster.metal
source/backend/tensorrt/execution/TRTRaster.cpp
source/backend/tensorrt/execution/plugin/Raster.cuh
source/backend/tensorrt/execution/plugin/RasterPlugin.hpp
source/backend/tensorrt/execution/plugin/RasterPlugin.cpp
source/backend/tensorrt/execution/plugin/Raster.cu
source/backend/tensorrt/execution/TRTRaster.hpp

This is awkward: OpenGL provides no implementation of the Raster operator at all!!!

3. If you have done a lot of extra optimization in the OpenCL backend, you can set its Compile_Type to Origin, like the current OpenGL backend configuration, so that the geometry-compute decomposition is not performed.

Two further details are worth noting:

  1. Even the most basic CPU backend does not support OpType_Pooling3D either, as the following grep shows:
./tools/converter/source/tensorflow/Pooling3DTf.cpp:16:    return MNN::OpType_Pooling3D;
./tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp:41:        case MNN::OpType_Pooling3D:
./tools/converter/source/optimizer/onnxextra/OnnxPooling.cpp:56:        dstOp->type       = MNN::OpType_Pooling3D;
./tools/converter/source/caffe/Pool.cpp:120:        return MNN::OpType_Pooling3D;
./schema/current/MNN_generated.h:201:  OpType_Pooling3D = 112,
./schema/current/MNN_generated.h:382:    OpType_Pooling3D,
./source/geometry/GeometryPooling3D.cpp:146:    GeometryComputer::registerGeometryComputer(comp, {OpType_Pooling3D});
./source/shape/ShapePool3D.cpp:88:REGISTER_SHAPE(Pool3DSizeComputer, OpType_Pooling3D);
./source/shape/ShapeRegister.cpp:70:extern void ___Pool3DSizeComputer__OpType_Pooling3D__();
./source/shape/ShapeRegister.cpp:180:___Pool3DSizeComputer__OpType_Pooling3D__();
./test/op/Pool3DTest.cpp:29:    op->type       = OpType_Pooling3D;
  2. Although OpenGL does not support the Raster operator, it can fall back to the CPU backend for execution, so the final result should in principle still be correct.
See the following code logic:
            if (nullptr == iter.execution) {
                iter.execution.reset(mBackend->onCreate(iter.inputs, iter.outputs, iter.op));
            }
            if (nullptr == iter.execution) {
                // Try Backup
                iter.execution.reset(mBackupBackend->onCreate(iter.inputs, iter.outputs, iter.op));
                if (nullptr == iter.execution) {
                    MNN_ERROR("Create exection error : %d\n", iter.op->type());
                    return NOT_SUPPORT;
                }
            }
Strangely, though, ResNet still runs wrong!
# libmali.so.1 -> /usr/lib/aarch64-linux-gnu/libGL.so.1.7.0
$ ./mnn_perf-test --only-test=res --backend=g
Creating MNN Interpreter: resnet50
The device support i8sdot:1, support fp16:1, support i8mm: 0
gpu type : Mali-G610 (Panfrost)
gl version : OpenGL ES 3.1 Mesa 23.0.0-devel
gpu type : Mali-G610 (Panfrost)
gl version : OpenGL ES 3.1 Mesa 23.0.0-devel
(index: 833,  score: 2.369141), (index: 921,  score: 1.832031), (index: 742,  score: 1.584961),
$ ./mnn_perf --only-test=res --backend=g
Creating MNN Interpreter: resnet50
The device support i8sdot:1, support fp16:1, support i8mm: 0
gpu type : Mali-G610 (Panfrost)
gl version : OpenGL ES 3.1 Mesa 23.0.0-devel
gpu type : Mali-G610 (Panfrost)
gl version : OpenGL ES 3.1 Mesa 23.0.0-devel
(index: 673,  score: 797.500000), (index: 396,  score: 547.500000), (index: 366,  score: 367.500000),
[63 iters] min = 315.68ms max = 331.70ms median = 316.48ms mean = 318.44ms mean = 318.93ms

Moreover, changing the number of inference iterations changes the result...

Switching to libmali-valhall-g610-g15p0-wayland.so works! (Though the library actually reports g6, so the g15 in the name is not accurate.)
# libmali.so.1 -> libmali-valhall-g610-g15p0-wayland.so
$ ./mnn_perf --only-test=res --backend=g
Creating MNN Interpreter: resnet50
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver of this libmali is 'g6p0-01eac0', rk_so_ver is '6'.
gpu type : Mali-LODX
gl version : OpenGL ES 3.2 v1.g6p0-01eac0.ba52c908d926792b8f5fe28f383a2b03
gpu type : Mali-LODX
gl version : OpenGL ES 3.2 v1.g6p0-01eac0.ba52c908d926792b8f5fe28f383a2b03
(index: 985,  score: 7.925781), (index: 113,  score: -5.296875), (index: 308,  score: -5.500000),
[274 iters] min =  72.53ms max =  80.66ms median =  73.02ms mean =  73.04ms mean =  73.20ms
Also note that switching to libmali-valhall-g610-g13p0-x11-wayland-gbm.so does not work either:
# libmali.so.1 -> libmali-valhall-g610-g13p0-x11-wayland-gbm.so
$ ./mnn_perf --only-test=res --backend=g
Creating MNN Interpreter: resnet50
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver: g13p0-01eac0, rk_so_ver: 3
eglChooseConfig error !!!
mContext error !!!
[1]    128017 segmentation fault  ./mnn_perf --only-test=res --backend=g

OpenCL (Edge2)

install clinfo and some dev libraries first:
$ apt search ocl-icd
Sorting... Done
Full Text Search... Done
ocl-icd-dev/jammy,now 2.2.14-3 arm64 [installed]
  Development files to build an OpenCL ICD

ocl-icd-libopencl1/jammy,now 2.2.14-3 arm64 [installed]
  Generic OpenCL ICD Loader

ocl-icd-opencl-dev/jammy,now 2.2.14-3 arm64 [installed]
  OpenCL development files

sudo apt-get install clinfo ocl-icd-libopencl1 ocl-icd-dev ocl-icd-opencl-dev libhwloc-dev

OpenCL debug journey (edge2)

After updating the repository to the latest commit, everything got much better!!!

Hacked-down LeViT for debugging:
diff --git a/levit.py b/levit.py
index 7fa515d..6b9ebef 100644
--- a/levit.py
+++ b/levit.py
@@ -147,6 +147,7 @@ class Linear_BN(torch.nn.Sequential):

     def forward(self, x):
         l, bn = self._modules.values()
+        return bn(x)
         x = l(x)
         return bn(x.flatten(0, 1)).reshape_as(x)

@@ -198,6 +199,7 @@ class Residual(torch.nn.Module):
         self.drop = drop

     def forward(self, x):
+        return self.m(x)
         if self.training and self.drop > 0:
             return x + self.m(x) * torch.rand(x.size(0), 1, 1,
                                               device=x.device).ge_(self.drop).div(1 - self.drop).detach()
@@ -255,8 +257,9 @@ class Attention(torch.nn.Module):
             self.ab = self.attention_biases[:, self.attention_bias_idxs]

     def forward(self, x):  # x (B,N,C)
-        B, N, C = x.shape
+        # B, N, C = x.shape
         qkv = self.qkv(x)
+        return qkv
         q, k, v = qkv.view(B, N, self.num_heads, -
                            1).split([self.key_dim, self.key_dim, self.d], dim=3)
         q = q.permute(0, 2, 1, 3)
@@ -464,9 +467,9 @@ class LeViT(torch.nn.Module):
         return {x for x in self.state_dict().keys() if 'attention_biases' in x}

     def forward(self, x):
-        x = self.patch_embed(x)
-        x = x.flatten(2).transpose(1, 2)
-        x = self.blocks(x)
+        x = x.view(196, 256)
+        x = self.blocks[0](x)
+        return x
         x = x.mean(1)
         if self.distillation:
             x = self.head(x), self.head_dist(x)
This yields a debug network whose structure is very simple: (image)

python convert.py --only-convert LeViT_128S --debug 196,16

Output on each backend:

$ ./mnn_perf-debug --only-test=debug --backend=g
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver of this libmali is 'g6p0-01eac0', rk_so_ver is '6'.
gpu type : Mali-LODX
gl version : OpenGL ES 3.2 v1.g6p0-01eac0.ba52c908d926792b8f5fe28f383a2b03
gpu type : Mali-LODX
gl version : OpenGL ES 3.2 v1.g6p0-01eac0.ba52c908d926792b8f5fe28f383a2b03
Don't support type 128
Don't support type 77, output
[len: 50176] (0: 0.22279) (1: -1.49101) (-2:-1.48576) (-1:1.51551)
$ ./mnn_perf-debug --only-test=debug
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
[len: 50176] (0: 0.22279) (1: -1.49101) (-2:-1.48576) (-1:1.51551)
$ ./mnn_perf-debug --only-test=debug --backend=o
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver: g13p0-01eac0, rk_so_ver: 3
[len: 50176] (0: nan) (1: 2.84766) (-2:-0.341064) (-1:-0.207031)
$ ./mnn_perf-debug --only-test=debug --backend=o
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver of this libmali is 'g6p0-01eac0', rk_so_ver is '6'.
[len: 50176] (0: nan) (1: 2.84766) (-2:-0.341064) (-1:-0.207031)

We can see that the OpenCL backend (with both of the different GPU drivers) gives wrong results, and the same thing happens on Android, so it is most likely not a driver problem.

The current workaround is to fall op 77 back to the CPU.

Combined with a common reshape bug found earlier in the Vulkan backend, this gives the following patch:
diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp
index 20942e0..e513851 100644
--- a/source/backend/opencl/core/OpenCLBackend.cpp
+++ b/source/backend/opencl/core/OpenCLBackend.cpp
@@ -406,7 +406,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
     auto creators = gCreator();
     auto iter      = creators->find(std::make_pair(op->type(), mOpenCLRuntime->getGpuMemType()));

-    if (iter == creators->end()) {
+    if (iter == creators->end() || op->type() == 128 || op->type() == 77) {
         #if 0//close log
         if (nullptr != op->name()) {
             MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str());

With this, on Edge2 Linux the following models now produce results:

  • LeViT
  • efficientformerv2
  • SwiftFormer

but the accuracy is off, and further debugging of LeViT shows that the lost accuracy cannot simply be recovered by falling some operators back to the CPU, which is strange.

A hacked-down efficientformer illustrates this well:
  • one layer
diff --git a/efficientformer_v2.py b/efficientformer_v2.py
index 2ee5d2c..16a7169 100644
--- a/efficientformer_v2.py
+++ b/efficientformer_v2.py
@@ -379,18 +379,19 @@ class Mlp(nn.Module):
     def forward(self, x):
         x = self.fc1(x)
         x = self.norm1(x)
+        return x
         x = self.act(x)

         if self.mid_conv:
             x_mid = self.mid(x)
             x_mid = self.mid_norm(x_mid)
             x = self.act(x_mid)
-        x = self.drop(x)
+        #x = self.drop(x)

         x = self.fc2(x)
         x = self.norm2(x)

-        x = self.drop(x)
+        #x = self.drop(x)
         return x


@@ -447,6 +448,7 @@ class FFN(nn.Module):
                 layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

     def forward(self, x):
+        return self.mlp(x)
         if self.use_layer_scale:
             x = x + self.drop_path(self.layer_scale_2 * self.mlp(x))
         else:
@@ -621,7 +623,8 @@ class EfficientFormerV2(nn.Module):
     def forward_tokens(self, x):
         outs = []
         for idx, block in enumerate(self.network):
-            x = block(x)
+            x = block[0](x)
+            return x
             if self.fork_feat and idx in self.out_indices:
                 norm_layer = getattr(self, f'norm{idx}')
                 x_out = norm_layer(x)
@@ -631,8 +634,9 @@ class EfficientFormerV2(nn.Module):
         return x

     def forward(self, x):
-        x = self.patch_embed(x)
+        # x = self.patch_embed(x)
         x = self.forward_tokens(x)
+        return x
         if self.fork_feat:
             # otuput features of four stages for dense prediction
             return x
python convert.py --only-convert efficientformerv2_s0 --debug 32,56
(image)
$ ./mnn_perf-debug --only-test=debug
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
[len: 401408] (0: 0.545583) (1: 0.545583) (-2:-6.11306) (-1:-6.11306)
$ ./mnn_perf-debug --only-test=debug --backend=o
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver: g13p0-01eac0, rk_so_ver: 3
[len: 401408] (0: 0.544922) (1: 0.544922) (-2:-6.10156) (-1:-6.10156)
  • two layers
(image)
    def forward(self, x):
        x = self.fc1(x)
        x = self.norm1(x)
        x = self.act(x)

        if self.mid_conv:
            x_mid = self.mid(x)
            x_mid = self.mid_norm(x_mid)
            return x_mid
$ ./mnn_perf-debug --only-test=debug
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
[len: 401408] (0: 9.74586) (1: 10.3657) (-2:3.49701) (-1:3.49701)
$ ./mnn_perf-debug --only-test=debug --backend=o
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver: g13p0-01eac0, rk_so_ver: 3
[len: 401408] (0: 9.72656) (1: 10.3438) (-2:3.49609) (-1:3.49609)
  • three layers
(image)
    def forward(self, x):
        x = self.fc1(x)
        x = self.norm1(x)
        x = self.act(x)

        if self.mid_conv:
            x_mid = self.mid(x)
            x_mid = self.mid_norm(x_mid)
            x = self.act(x_mid)
        #x = self.drop(x)

        x = self.fc2(x)
        x = self.norm2(x)
$ ./mnn_perf-debug --only-test=debug
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
[len: 100352] (0: 1.05672) (1: 1.14461) (-2:-0.517402) (-1:-1.34904)
$ ./mnn_perf-debug --only-test=debug --backend=o
Creating MNN Interpreter: debug
The device support i8sdot:1, support fp16:1, support i8mm: 0
arm_release_ver: g13p0-01eac0, rk_so_ver: 3
[len: 100352] (0: 1.0459) (1: 1.13379) (-2:-0.522461) (-1:-1.35449)

We can see that the error accumulates layer after layer.

The other three models,

  • EMO
  • edgenext
  • mobilevit
all fail because the Mali GPU's OpenCL compiler rejects the way MNN writes its layernorm kernel:
Program build log: <source>:53:55: error: scalar operand type has greater rank than the type of the vector element. ('half4' (vector of 4 'half' values) and 'float')
        FLOAT4 value = (FLOAT4)1.0f / sqrt(square_sum + epsilon);
                                           ~~~~~~~~~~ ^ ~~~~~~~

<source>:115:55: error: scalar operand type has greater rank than the type of the vector element. ('half4' (vector of 4 'half' values) and 'float')
        FLOAT4 value = (FLOAT4)1.0f / sqrt(square_sum + epsilon);
                                           ~~~~~~~~~~ ^ ~~~~~~~

<source>:204:55: error: scalar operand type has greater rank than the type of the vector element. ('half4' (vector of 4 'half' values) and 'float')
        FLOAT4 value = (FLOAT4)1.0f / sqrt(square_sum + epsilon);
                                           ~~~~~~~~~~ ^ ~~~~~~~

error: Compiler frontend failed (error code 63)
To work around this, op 603 is bypassed to the CPU; the change is as follows:
diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp
index 20942e0..98a8e62 100644
--- a/source/backend/opencl/core/OpenCLBackend.cpp
+++ b/source/backend/opencl/core/OpenCLBackend.cpp
@@ -406,7 +406,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
     auto creators = gCreator();
     auto iter      = creators->find(std::make_pair(op->type(), mOpenCLRuntime->getGpuMemType()));

-    if (iter == creators->end()) {
+    if (iter == creators->end() || op->type() == 128 || op->type() == 77 || op->type() == 603) {
         #if 0//close log
         if (nullptr != op->name()) {
             MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str());

This matches the behavior of the Mali GPU under Android, except for LeViT(_192); the two closed-source GPU drivers at hand (g15/g6 and g13) currently behave differently.

g13 feels somewhat better.

Android

Vulkan (Vim3)

Most of the transformer-based SOTA models give the following output:
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
Only the following give correct results:
Creating MNN Interpreter: EMO_2M
(index: 985,  score: 9.460938), (index: 309,  score: 3.371094), (index: 308,  score: 3.232422),
Creating MNN Interpreter: EMO_5M
(index: 985,  score: 9.179688), (index: 883,  score: 2.812500), (index: 872,  score: 2.550781),

On Jetson Orin, all of the EMO models are correct.

There are also models that give wrong results:
Creating MNN Interpreter: mobilevit_xx_small
(index: 378,  score: 6.980469), (index: 388,  score: 6.609375), (index: 387,  score: 6.449219),
Creating MNN Interpreter: mobilevit_x_small
(index: 254,  score: 2.107422), (index: 245,  score: 0.705078), (index: 242,  score: 0.056641),
Creating MNN Interpreter: mobilevit_small
(index: 306,  score: 2.326172), (index: 114,  score: 1.714844), (index: 686,  score: 1.661133),

vulkan debug journey

After a long and arduous effort, the problem was finally pinned down.
diff --git a/efficientformer_v2.py b/efficientformer_v2.py
index 2ee5d2c..eaf2d7c 100644
--- a/efficientformer_v2.py
+++ b/efficientformer_v2.py
@@ -377,6 +377,7 @@ class Mlp(nn.Module):
                 nn.init.constant_(m.bias, 0)

     def forward(self, x):
+        return x
         x = self.fc1(x)
         x = self.norm1(x)
         x = self.act(x)
@@ -447,6 +448,8 @@ class FFN(nn.Module):
                 layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

     def forward(self, x):
+        x = self.layer_scale_2 * self.mlp(x)
+        return x
         if self.use_layer_scale:
             x = x + self.drop_path(self.layer_scale_2 * self.mlp(x))
         else:
@@ -621,7 +624,8 @@ class EfficientFormerV2(nn.Module):
     def forward_tokens(self, x):
         outs = []
         for idx, block in enumerate(self.network):
-            x = block(x)
+            x = block[0](x)
+            break
             if self.fork_feat and idx in self.out_indices:
                 norm_layer = getattr(self, f'norm{idx}')
                 x_out = norm_layer(x)
@@ -631,8 +635,8 @@ class EfficientFormerV2(nn.Module):
         return x

     def forward(self, x):
-        x = self.patch_embed(x)
         x = self.forward_tokens(x)
+        return x
         if self.fork_feat:
             # otuput features of four stages for dense prediction
             return x
python convert.py --only-convert efficientformerv2_s0 --format onnx --debug 32,56
diff --git a/source/backend/vulkan/image/backend/VulkanBackend.cpp b/source/backend/vulkan/image/backend/VulkanBackend.cpp
index 4587af7..33547b4 100644
--- a/source/backend/vulkan/image/backend/VulkanBackend.cpp
+++ b/source/backend/vulkan/image/backend/VulkanBackend.cpp
@@ -309,6 +309,8 @@ void VulkanBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTenso
             iter = mConverters.find(key);
         }
         mCmdBuffers.push_back(iter->second.second->get());
+        _finish();
+        mHostBuffer = nullptr;
     } else if (dstTensor->host<void>() != nullptr) {
         // gpu->host
         auto size = VulkanTensor::getAlignSize(srcTensor) * sizeof(float);
With this, mobilevitv2 runs through. The before/after performance comparison is below; there is still a slight performance loss.
### before
$ ./mnn_perf --only-test=mobilevitv --backend=v
Creating MNN Interpreter: mobilevitv2_050
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[82 iters] min = 181.93ms max = 287.01ms median = 261.80ms mean = 244.07ms mean = 244.52ms
Creating MNN Interpreter: mobilevitv2_075
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[53 iters] min = 304.05ms max = 465.72ms median = 378.59ms mean = 382.60ms mean = 383.11ms
Creating MNN Interpreter: mobilevitv2_100
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[40 iters] min = 454.53ms max = 556.97ms median = 525.17ms mean = 500.00ms mean = 500.48ms
Creating MNN Interpreter: mobilevitv2_125
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[32 iters] min = 602.33ms max = 674.91ms median = 641.82ms mean = 640.74ms mean = 641.18ms
Creating MNN Interpreter: mobilevitv2_150
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[23 iters] min = 839.84ms max = 917.32ms median = 882.97ms mean = 883.69ms mean = 884.11ms
Creating MNN Interpreter: mobilevitv2_175
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[19 iters] min =1069.28ms max =1163.13ms median =1070.94ms mean =1089.48ms mean =1089.88ms
Creating MNN Interpreter: mobilevitv2_200
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[16 iters] min =1282.04ms max =1338.45ms median =1308.50ms mean =1309.95ms mean =1310.36ms
-----------------------------------------------------------------------------------------------
### after
$ ./mnn_perf --only-test=mobilevitv --backend=v
Creating MNN Interpreter: mobilevitv2_050
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 985,  score: 8.414062), (index: 309,  score: 2.667969), (index: 89,  score: 2.441406),
[79 iters] min = 186.94ms max = 310.02ms median = 266.11ms mean = 250.80ms mean = 253.91ms
Creating MNN Interpreter: mobilevitv2_075
(index: 985,  score: 8.257812), (index: 309,  score: 2.695312), (index: 308,  score: 2.117188),
[52 iters] min = 306.09ms max = 468.13ms median = 380.91ms mean = 384.01ms mean = 386.83ms
Creating MNN Interpreter: mobilevitv2_100
(index: 985,  score: 8.226562), (index: 557,  score: 2.294922), (index: 309,  score: 2.080078),
[40 iters] min = 450.82ms max = 562.04ms median = 507.25ms mean = 500.48ms mean = 503.14ms
Creating MNN Interpreter: mobilevitv2_125
(index: 985,  score: 8.453125), (index: 309,  score: 2.093750), (index: 113,  score: 1.363281),
[31 iters] min = 619.29ms max = 674.50ms median = 664.89ms mean = 648.03ms mean = 650.65ms
Creating MNN Interpreter: mobilevitv2_150
(index: 985,  score: 8.703125), (index: 818,  score: 2.304688), (index: 308,  score: 2.257812),
[23 iters] min = 869.48ms max = 921.55ms median = 896.59ms mean = 897.98ms mean = 900.47ms
Creating MNN Interpreter: mobilevitv2_175
(index: 985,  score: 8.640625), (index: 494,  score: 1.972656), (index: 968,  score: 1.944336),
[19 iters] min =1071.82ms max =1135.92ms median =1135.92ms mean =1098.78ms mean =1101.16ms
Creating MNN Interpreter: mobilevitv2_200
(index: 985,  score: 8.664062), (index: 309,  score: 2.351562), (index: 308,  score: 2.244141),
[16 iters] min =1301.85ms max =1368.70ms median =1324.45ms mean =1322.81ms mean =1325.30ms
Next, while tracing LeViT, the following faulty case turned up.

(image)

As the figure above shows, x.view(*x.shape[:2], -1) (right) is functionally identical to x.flatten(2) (left), but the actual network generated from x.view(*x.shape[:2], -1) is clearly much simpler.

diff --git a/levit.py b/levit.py
index 7fa515d..4f6eb7f 100644
--- a/levit.py
+++ b/levit.py
@@ -198,6 +198,7 @@ class Residual(torch.nn.Module):
         self.drop = drop

     def forward(self, x):
+        return self.m(x)
         if self.training and self.drop > 0:
             return x + self.m(x) * torch.rand(x.size(0), 1, 1,
                                               device=x.device).ge_(self.drop).div(1 - self.drop).detach()
@@ -256,9 +257,10 @@ class Attention(torch.nn.Module):

     def forward(self, x):  # x (B,N,C)
         B, N, C = x.shape
-        qkv = self.qkv(x)
+        qkv = x
         q, k, v = qkv.view(B, N, self.num_heads, -
                            1).split([self.key_dim, self.key_dim, self.d], dim=3)
+        return q
         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 1, 3)
         v = v.permute(0, 2, 1, 3)
@@ -464,9 +466,10 @@ class LeViT(torch.nn.Module):
         return {x for x in self.state_dict().keys() if 'attention_biases' in x}

     def forward(self, x):
-        x = self.patch_embed(x)
-        x = x.flatten(2).transpose(1, 2)
-        x = self.blocks(x)
+        # x = x.flatten(2)
+        x = x.view(*x.shape[:2], -1)
+        x = self.blocks[0](x)
+        return x
         x = x.mean(1)
         if self.distillation:
             x = self.head(x), self.head_dist(x)

A description of the problem can be found at https://github.com/alibaba/MNN/issues/2562

Described in PyTorch, it looks like this:
import torch
import torch.nn as nn
x = torch.randn(1*16*16*16)
for i in range(4096): x[i] = i
q,k,v = x.view(1,16,4,64).split([16,16,32], dim=3)
print(q[0][-1][-1][-2], q[0][-1][-1][-1])
> tensor(4046.) tensor(4047.)
# get model
python convert.py --only-convert LeViT_128S --debug 16,16
cd mnn; ./MNNConvert -f ONNX --modelFile ../onnx/debug.onnx --MNNModel debug.mnn --bizCode MNN
# get debug exe
g++ -g -DDEBUG -DDEBUG_C=4096 -o mnn_perf-debug mnn_perf.cpp utils.cpp  -std=c++17 -I$MNN_INC -L$MNN_LIB -lMNN `pkg-config --cflags --libs opencv4`

With the CPU backend we get the correct result:

$ ./mnn_perf-test --only-test=debug
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
[len: 1024] (0: 0) (1: 1) (-2:4046) (-1:4047)

But the OpenCL and Vulkan backends give a wrong result:

$ ./mnn_perf-test --only-test=debug --backend=o
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
[len: 1024] (0: 0) (1: 1) (-2:4046) (-1:4048)
$ ./mnn_perf-test --only-test=debug --backend=v
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
[len: 1024] (0: 0) (1: 1) (-2:4046) (-1:4048)
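
As a sanity check on those numbers, the expected last two entries of q follow from plain index arithmetic (a small stand-alone sketch, not part of the benchmark), confirming that 4047, not 4048, is correct:

```cpp
#include <cstdio>

int main() {
    // x = arange(4096) viewed as (1, 16, 4, 64); q keeps channels 0..15 of the
    // last dim, so q[0][-1][-1][-2] and q[0][-1][-1][-1] sit at (n=15, h=3, c=14)
    // and (n=15, h=3, c=15) of the original buffer.
    auto offset = [](int n, int h, int c) { return (n * 4 + h) * 64 + c; };
    printf("%d %d\n", offset(15, 3, 14), offset(15, 3, 15)); // prints: 4046 4047
    return 0;
}
```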

Since OpenGL cannot handle op 128 on the GPU, it falls back to the CPU, so its result is also correct:

$ ./mnn_perf-test --only-test=debug --backend=g
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
gpu type : Mali-G52
gl version : OpenGL ES 3.2 v1.r16p0-01rel0-32c0b2f.a081bcdde22f85278903af76f9cc56e9
gpu type : Mali-G52
gl version : OpenGL ES 3.2 v1.r16p0-01rel0-32c0b2f.a081bcdde22f85278903af76f9cc56e9
Don't support type 128
Don't support type 128
Don't support type 128
[len: 1024] (0: 0) (1: 1) (-2:4046) (-1:4047)

Moreover, not every case is wrong: if the input size is reduced to 4*256, the results are as follows:

python convert.py --only-convert LeViT_128S --debug 4,16
cd mnn; ./MNNConvert -f ONNX --modelFile ../onnx/debug.onnx --MNNModel debug.mnn --bizCode MNN
g++ -g -DDEBUG -DDEBUG_C=4096 -o mnn_perf-debug mnn_perf.cpp utils.cpp  -std=c++17 -I$MNN_INC -L$MNN_LIB -lMNN `pkg-config --cflags --libs opencv4`
$ ./mnn_perf-test --only-test=debug
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
[len: 256] (0: 0) (1: 1) (-2:974) (-1:975)
$ ./mnn_perf-test --only-test=debug --backend=v
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
[len: 256] (0: 0) (1: 1) (-2:974) (-1:975)
$ ./mnn_perf-test --only-test=debug --backend=o
Creating MNN Interpreter: debug
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
[len: 256] (0: 0) (1: 1) (-2:974) (-1:975)
So let's fall op 128 back to the CPU:
diff --git a/source/backend/vulkan/image/backend/VulkanBackend.cpp b/source/backend/vulkan/image/backend/VulkanBackend.cpp
index 33547b4..c38cbb2 100644
--- a/source/backend/vulkan/image/backend/VulkanBackend.cpp
+++ b/source/backend/vulkan/image/backend/VulkanBackend.cpp
@@ -164,7 +164,7 @@ Execution* VulkanBackend::onCreate(const std::vector<Tensor*>& inputs, const std
     if (nullptr != op->name()) {
         name = op->name()->str();
     }
-    if (iter == creator->end()) {
+    if (iter == creator->end() || op->type() == 128) {
 #ifdef MNN_OP_SUPPORT_LOG
         MNN_PRINT("Vulkan don't support %d, %s: %s\n", op->type(), EnumNameOpType(op->type()),
                 name.c_str());

which gives:

$ ./mnn_perf-test --only-test=Le --backend=v
Creating MNN Interpreter: LeViT_128S
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 999,  score: 6.109375), (index: 985,  score: 5.468750), (index: 716,  score: 5.066406),
Creating MNN Interpreter: LeViT_128
(index: 985,  score: 7.898438), (index: 465,  score: 6.152344), (index: 968,  score: 5.843750),
Creating MNN Interpreter: LeViT_192
(index: 985,  score: 8.898438), (index: 947,  score: 6.449219), (index: 992,  score: 5.289062),
Creating MNN Interpreter: LeViT_256
(index: 879,  score: 8.710938), (index: 112,  score: 7.492188), (index: 999,  score: 6.875000),

Now it prints concrete numbers, but the results are still wrong. There is an unexpected bonus, though: mobilevit now runs correctly, and this change even improves its performance a little, amazingly!

########## before
$ ./mnn_perf --only-test mobilevit_ --backend=v
Creating MNN Interpreter: mobilevit_xx_small
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 378,  score: 6.980469), (index: 388,  score: 6.609375), (index: 387,  score: 6.449219),
[37 iters] min = 423.82ms max = 618.31ms median = 468.95ms mean = 547.39ms mean = 551.33ms
Creating MNN Interpreter: mobilevit_x_small
(index: 254,  score: 2.107422), (index: 245,  score: 0.705078), (index: 242,  score: 0.056641),
[38 iters] min = 450.06ms max = 709.02ms median = 479.54ms mean = 526.55ms mean = 529.76ms
Creating MNN Interpreter: mobilevit_small
(index: 306,  score: 2.326172), (index: 114,  score: 1.714844), (index: 686,  score: 1.661133),
[30 iters] min = 574.93ms max = 877.98ms median = 626.55ms mean = 670.52ms mean = 673.68ms
----------------------------------------------------------------------------------------------
########## after
$ ./mnn_perf --only-test mobilevit_ --backend=v
Creating MNN Interpreter: mobilevit_xx_small
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 985,  score: 12.460938), (index: 309,  score: 6.507812), (index: 308,  score: 6.222656),
[37 iters] min = 318.21ms max = 617.63ms median = 570.81ms mean = 542.69ms mean = 546.55ms
Creating MNN Interpreter: mobilevit_x_small
(index: 985,  score: 13.070312), (index: 89,  score: 6.812500), (index: 309,  score: 5.839844),
[40 iters] min = 377.96ms max = 665.04ms median = 597.75ms mean = 499.82ms mean = 501.56ms
Creating MNN Interpreter: mobilevit_small
(index: 985,  score: 10.468750), (index: 309,  score: 3.750000), (index: 838,  score: 3.707031),
[31 iters] min = 584.63ms max = 750.87ms median = 750.87ms mean = 650.45ms mean = 652.12ms

On Jetson Orin, this change makes the following models run correctly:

  • SwiftFormer
  • mobilevit
  • edgenext
The before/after performance comparison for these three models on Jetson Orin with the fix:

(image)

Continuing to debug, another trivial bug turned up:
diff --git a/levit.py b/levit.py
index 7fa515d..1010191 100644
--- a/levit.py
+++ b/levit.py
@@ -198,6 +198,7 @@ class Residual(torch.nn.Module):
         self.drop = drop

     def forward(self, x):
+        return self.m(x)
         if self.training and self.drop > 0:
             return x + self.m(x) * torch.rand(x.size(0), 1, 1,
                                               device=x.device).ge_(self.drop).div(1 - self.drop).detach()
@@ -255,6 +256,7 @@ class Attention(torch.nn.Module):
             self.ab = self.attention_biases[:, self.attention_bias_idxs]

     def forward(self, x):  # x (B,N,C)
+        """
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.view(B, N, self.num_heads, -
@@ -262,13 +264,17 @@ class Attention(torch.nn.Module):
         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 1, 3)
         v = v.permute(0, 2, 1, 3)
+        """

+        x = x + self.ab
+        return x
         attn = (
             (q @ k.transpose(-2, -1)) * self.scale
             +
             (self.attention_biases[:, self.attention_bias_idxs]
              if self.training else self.ab)
         )
+
         attn = attn.softmax(dim=-1)
         x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
@@ -464,9 +470,8 @@ class LeViT(torch.nn.Module):
         return {x for x in self.state_dict().keys() if 'attention_biases' in x}

     def forward(self, x):
-        x = self.patch_embed(x)
-        x = x.flatten(2).transpose(1, 2)
-        x = self.blocks(x)
+        x = self.blocks[0](x)
+        return x
         x = x.mean(1)
         if self.distillation:
             x = self.head(x), self.head_dist(x)

python convert.py --only-convert LeViT_128S --debug 4,196

This gives the MNN network model: (image)

The bug is described at https://github.com/alibaba/MNN/issues/2562

The current workaround is to fall it back to the CPU:
diff --git a/source/backend/vulkan/image/execution/VulkanBinary.cpp b/source/backend/vulkan/image/execution/VulkanBinary.cpp
index b3f9eba..f68189e 100644
--- a/source/backend/vulkan/image/execution/VulkanBinary.cpp
+++ b/source/backend/vulkan/image/execution/VulkanBinary.cpp
@@ -170,6 +170,16 @@ public:
     virtual VulkanBasicExecution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
                                 Backend* backend) const override {
         auto input0 = inputs[0];
+        for (int i = 1; i < inputs.size(); ++i) {
+            auto input = inputs[i];
+           if (input0->dimensions()  + input->dimensions() == 7) return nullptr;
+           /*
+            if (input0->dimensions() != input->dimensions()) {
+                MNN_PRINT("dimensions : [%d, %d] \n", input0->dimensions(), input->dimensions());
+                MNN_PRINT("vulkan binary don't support broatcast !!! \n");
+            }
+           */
+        }
         auto image = TensorUtils::getDescribe(input0)->dimensionFormat == MNN_DATA_FORMAT_NC4HW4;
         auto shader = _getShaderName(op, image);
         if (shader.empty()) {
With this, LeViT finally runs; the before/after performance comparison:
########### before
$ ./mnn_perf --only-test Le --backend=v
Creating MNN Interpreter: LeViT_128S
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[92 iters] min = 146.31ms max = 235.12ms median = 222.81ms mean = 214.94ms mean = 218.09ms
Creating MNN Interpreter: LeViT_128
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[66 iters] min = 184.64ms max = 338.27ms median = 328.94ms mean = 301.46ms mean = 304.59ms
Creating MNN Interpreter: LeViT_192
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[60 iters] min = 241.25ms max = 422.28ms median = 388.10ms mean = 333.36ms mean = 336.37ms
Creating MNN Interpreter: LeViT_256
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[50 iters] min = 280.61ms max = 540.58ms median = 420.44ms mean = 401.42ms mean = 404.50ms
------------------------------------------------------------------------------------------
########### after
$ ./mnn_perf --only-test Le --backend=v
Creating MNN Interpreter: LeViT_128S
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 985,  score: 11.703125), (index: 308,  score: 3.591797), (index: 309,  score: 3.382812),
[46 iters] min = 336.30ms max = 477.36ms median = 459.20ms mean = 437.99ms mean = 441.49ms
Creating MNN Interpreter: LeViT_128
(index: 985,  score: 11.343750), (index: 309,  score: 3.447266), (index: 113,  score: 3.291016),
[33 iters] min = 506.19ms max = 671.31ms median = 641.24ms mean = 613.40ms mean = 617.43ms
Creating MNN Interpreter: LeViT_192
(index: 985,  score: 11.781250), (index: 324,  score: 3.406250), (index: 326,  score: 3.328125),
[28 iters] min = 569.09ms max = 791.01ms median = 756.90ms mean = 733.99ms mean = 737.39ms
Creating MNN Interpreter: LeViT_256
(index: 985,  score: 11.195312), (index: 108,  score: 3.046875), (index: 309,  score: 2.949219),
[23 iters] min = 728.85ms max = 957.64ms median = 793.62ms mean = 893.77ms mean = 895.65ms

The impact on performance is clearly substantial.

At this point, however, a large swath of models on Vim3 still output nan. To get to the bottom of this, the debugging continues, stripping efficientformerv2 down

to construct our test network, shown below: (image)

It is very simple.

diff --git a/efficientformer_v2.py b/efficientformer_v2.py
index 2ee5d2c..f2df920 100644
--- a/efficientformer_v2.py
+++ b/efficientformer_v2.py
@@ -377,14 +377,17 @@ class Mlp(nn.Module):
                 nn.init.constant_(m.bias, 0)

     def forward(self, x):
+        """
         x = self.fc1(x)
         x = self.norm1(x)
         x = self.act(x)
+        """

         if self.mid_conv:
-            x_mid = self.mid(x)
-            x_mid = self.mid_norm(x_mid)
+            #x_mid = self.mid(x)
+            x_mid = self.mid_norm(x)
             x = self.act(x_mid)
+        return x
         x = self.drop(x)

         x = self.fc2(x)
@@ -447,6 +450,7 @@ class FFN(nn.Module):
                 layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

     def forward(self, x):
+        return self.mlp(x)
         if self.use_layer_scale:
             x = x + self.drop_path(self.layer_scale_2 * self.mlp(x))
         else:
@@ -621,7 +625,8 @@ class EfficientFormerV2(nn.Module):
     def forward_tokens(self, x):
         outs = []
         for idx, block in enumerate(self.network):
-            x = block(x)
+            x = block[0](x)
+            break
             if self.fork_feat and idx in self.out_indices:
                 norm_layer = getattr(self, f'norm{idx}')
                 x_out = norm_layer(x)
@@ -631,8 +636,9 @@ class EfficientFormerV2(nn.Module):
         return x

     def forward(self, x):
-        x = self.patch_embed(x)
+        # x = self.patch_embed(x)
         x = self.forward_tokens(x)
+        return x
         if self.fork_feat:
             # otuput features of four stages for dense prediction
             return x
python convert.py --only-convert efficientformerv2_s0 --debug 128,56

After the first layer (batchnorm), the result is [len: 401408] (0: -72.4048) (1: -72.4048) (-2:3.8044) (-1:3.8044).

Then, passing through the second layer (GELU), the problem appears. The correct result should be [len: 401408] (0: -0) (1: -0) (-2:3.80078) (-1:3.80078); Orin's CPU and Vulkan backends as well as Vim3's CPU all give this same correct result, but Vim3's Vulkan alone gives the bizarre [len: 401408] (0: nan) (1: nan) (-2:3.80469) (-1:3.80469).

This leads to the following conclusions:

  1. It should be a bug in the Vim3 Android Vulkan driver; since it is closed source, there is nothing I can do about it.
  2. It also explains why some of the EMO models run correctly while the others produce nan: the models that produce nan generate an excessively large negative value at the batchnorm! (See the reference check below.)
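
As a reference check (my own sketch using the erf form of GELU, not MNN's kernel), these are the values the second layer should produce for those batchnorm outputs:

```cpp
#include <cmath>
#include <cstdio>

// erf-form GELU, used here only as a reference for the expected values.
static float gelu(float x) {
    return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

int main() {
    printf("gelu(-72.4048) = %g\n", gelu(-72.4048f)); // ~ -0, never nan
    printf("gelu(3.8044)   = %g\n", gelu(3.8044f));   // ~ 3.80, matching the CPU result
    return 0;
}
```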

After updating the repository to the latest commit, everything got much better!!!

  • efficientformerv2 still shows a clear accuracy error
  • LeViT has a huge accuracy error and is unusable!

The two earlier patches are needed to fix the accuracy!

OpenGL (Vim3)

pkg install mesa-dev
As it turns out, OpenGL is not much use; we might as well do without it...

(image)

emmmmm..... /system/lib64/libz.so is broken... Use the following command to pick up termux's copy of the library instead:

export LD_PRELOAD=/data/data/com.termux/files/usr/lib/libz.so

and then it can keep running, but emmmm......:

Creating MNN Interpreter: resnet50
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
gpu type : Mali-G52
gl version : OpenGL ES 3.2 v1.r16p0-01rel0-32c0b2f.a081bcdde22f85278903af76f9cc56e9
gpu type : Mali-G52
gl version : OpenGL ES 3.2 v1.r16p0-01rel0-32c0b2f.a081bcdde22f85278903af76f9cc56e9
Don't support type 112, /global_pool/pool/GlobalAveragePool_output_0
Don't support type [Pooling3D], /global_pool/pool/GlobalAveragePool_output_0
Create exection error : 112
Can't run session because not resized
(index: 999,  score: 0.000000), (index: 998,  score: 0.000000), (index: 997,  score: 0.000000),

The MNN OpenGL backend is no longer maintained; OpenCL / Vulkan are recommended instead. The reason for this error is that the OpenGL backend has not been adapted for geometry computation. Since the driver does not support the alternatives, I can only run MNN Vulkan from here on.

After the debug journey on Linux, the problem was quickly fixed here too:

diff --git a/source/backend/opengl/GLBackend.cpp b/source/backend/opengl/GLBackend.cpp
index b091237..737374e 100644
--- a/source/backend/opengl/GLBackend.cpp
+++ b/source/backend/opengl/GLBackend.cpp
@@ -469,6 +469,7 @@ int GLRuntime::onGetRuntimeStatus(RuntimeStatus statusEnum) const {
 }

 Runtime::CompilerType GLRuntime::onGetCompilerType() const {
+    return Compiler_Loop;
     return Compiler_Origin;
 }

Then old classics like ResNet can run normally and give the correct result!

$ ./mnn_perf --only-test=res --backend=g
Creating MNN Interpreter: resnet50
parsed /proc/cpuinfo Hardware = "Amlogic"
The device support i8sdot:0, support fp16:0, support i8mm: 0
gpu type : Mali-G52
gl version : OpenGL ES 3.2 v1.r16p0-01rel0-32c0b2f.a081bcdde22f85278903af76f9cc56e9
gpu type : Mali-G52
gl version : OpenGL ES 3.2 v1.r16p0-01rel0-32c0b2f.a081bcdde22f85278903af76f9cc56e9
Don't support type 128
Don't support type 128
(index: 985,  score: 8.000000), (index: 113,  score: -5.250000), (index: 307,  score: -5.468750),
[53 iters] min = 382.15ms max = 386.08ms median = 384.08ms mean = 383.91ms mean = 384.47ms

OpenCL

pkg install clinfo opencl-vendor-driver

OpenCL (SDM845) debug journey

Only mobilevitv2 is wildly wrong! Is op 603 the cause? See https://github.com/alibaba/MNN/issues/2574 for the detailed analysis.

The debug model is obtained by hacking the mobilevit_v2.py file:
diff --git a/mobilevit_v2.py b/mobilevit_v2.py
index d82ff06..de1addd 100644
--- a/mobilevit_v2.py
+++ b/mobilevit_v2.py
@@ -133,6 +133,7 @@ class LinearAttnFFN(nn.Module):
         )

     def forward(self, x: Tensor) -> Tensor:
+        return self.pre_norm_attn[0](x)
         # self-attention
         x = x + self.pre_norm_attn(x)
         # Feed forward network
@@ -344,7 +345,7 @@ class MobileViTBlockv2(nn.Module):
         return x

     def forward(self, x: Tensor) -> Tensor:
-        x = self.resize_input_if_needed(x)
+        # x = self.resize_input_if_needed(x)

         fm = self.local_rep(x)

@@ -355,7 +356,9 @@ class MobileViTBlockv2(nn.Module):
             patches, output_size = self.unfolding_pytorch(fm)

         # learn global representations on all patches
-        patches = self.global_rep(patches)
+        ## [1, 128, 4, 16]
+        patches = self.global_rep[0](patches)
+        return patches

         # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
         if self.coreml:
@@ -507,13 +510,16 @@ class MobileViTv2(nn.Module):
             pass

     def forward(self, x: Tensor) -> Tensor:
+        """
         x = self.conv_1(x)
         x = self.layer_1(x)
         x = self.layer_2(x)

         x = self.layer_3(x)
         x = self.layer_4(x)
-        x = self.layer_5(x)
+        """
+        x = self.layer_5[1](x)
+        return x
         x = self.classifier(x)
         return x
python convert.py --only-convert mobilevitv2_050 --debug 256,8

It looks like an op-603 problem; if op 603 is bypassed to the CPU, the before/after performance comparison is:

###### before
$ LD_LIBRARY_PATH=/vendor/lib64:$MNN_LIB ./mnn_perf --backend=o --only-test=mobilevitv
Creating MNN Interpreter: mobilevitv2_050
parsed /proc/cpuinfo Hardware = "Qualcomm Technologies, Inc SDM845"
The device support i8sdot:0, support fp16:1, support i8mm: 0
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[212 iters] min =  86.45ms max = 107.87ms median =  92.85ms mean =  93.87ms mean =  94.73ms
Creating MNN Interpreter: mobilevitv2_075
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[156 iters] min = 120.88ms max = 137.40ms median = 130.35ms mean = 127.31ms mean = 128.30ms
Creating MNN Interpreter: mobilevitv2_100
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),
[109 iters] min = 175.54ms max = 194.17ms median = 183.11ms mean = 182.26ms mean = 183.65ms
Creating MNN Interpreter: mobilevitv2_125
(index: 985,  score: 8.460938), (index: 309,  score: 2.080078), (index: 113,  score: 1.415039),
[76 iters] min = 258.05ms max = 274.18ms median = 268.46ms mean = 264.31ms mean = 265.22ms
Creating MNN Interpreter: mobilevitv2_150
(index: 985,  score: 8.968750), (index: 308,  score: 2.263672), (index: 301,  score: 2.218750),
[61 iters] min = 323.63ms max = 341.79ms median = 327.66ms mean = 329.72ms mean = 331.22ms
Creating MNN Interpreter: mobilevitv2_175
(index: 985,  score: 8.890625), (index: 494,  score: 2.076172), (index: 309,  score: 1.885742),
[51 iters] min = 389.33ms max = 404.03ms median = 397.05ms mean = 396.92ms mean = 398.61ms
Creating MNN Interpreter: mobilevitv2_200
(index: 985,  score: 8.570312), (index: 309,  score: 2.208984), (index: 308,  score: 2.171875),
[43 iters] min = 462.33ms max = 479.63ms median = 476.62ms mean = 472.68ms mean = 474.33ms
-----------------------------------------------------------------------------------------------
########### after
$ LD_LIBRARY_PATH=/vendor/lib64:$MNN_LIB ./mnn_perf --backend=o --only-test=mobilevitv
Creating MNN Interpreter: mobilevitv2_050
parsed /proc/cpuinfo Hardware = "Qualcomm Technologies, Inc SDM845"
The device support i8sdot:0, support fp16:1, support i8mm: 0
(index: 985,  score: 8.429688), (index: 309,  score: 2.652344), (index: 89,  score: 2.478516),
[200 iters] min =  92.66ms max = 113.57ms median =  97.50ms mean =  99.32ms mean = 100.39ms
Creating MNN Interpreter: mobilevitv2_075
(index: 985,  score: 8.281250), (index: 309,  score: 2.701172), (index: 308,  score: 2.132812),
[148 iters] min = 127.37ms max = 146.59ms median = 135.61ms mean = 134.09ms mean = 135.20ms
Creating MNN Interpreter: mobilevitv2_100
(index: 985,  score: 8.242188), (index: 557,  score: 2.316406), (index: 309,  score: 2.091797),
[106 iters] min = 181.88ms max = 201.91ms median = 183.70ms mean = 187.88ms mean = 189.08ms
Creating MNN Interpreter: mobilevitv2_125
(index: 985,  score: 8.460938), (index: 309,  score: 2.080078), (index: 113,  score: 1.415039),
[76 iters] min = 257.35ms max = 276.97ms median = 259.67ms mean = 264.50ms mean = 265.42ms
Creating MNN Interpreter: mobilevitv2_150
(index: 985,  score: 8.968750), (index: 308,  score: 2.263672), (index: 301,  score: 2.218750),
[61 iters] min = 318.29ms max = 334.91ms median = 324.95ms mean = 327.61ms mean = 329.07ms
Creating MNN Interpreter: mobilevitv2_175
(index: 985,  score: 8.890625), (index: 494,  score: 2.076172), (index: 309,  score: 1.885742),
[51 iters] min = 387.83ms max = 405.32ms median = 394.27ms mean = 397.02ms mean = 398.54ms
Creating MNN Interpreter: mobilevitv2_200
(index: 985,  score: 8.570312), (index: 309,  score: 2.208984), (index: 308,  score: 2.171875),
[43 iters] min = 464.43ms max = 484.69ms median = 475.38ms mean = 474.47ms mean = 476.11ms

OpenCL (Edge2 & Exynos 990) debug journey

  • mobilevit ✅

Runs correctly without any modification (note that softmax does not support dimensions == 3).

  • SwiftFormer ✅
  • EMO ✅

Symptom:

Map error scalePtrCL == nullptr
Map error biasPtrCL == nullptr
(index: 999,  score: nan), (index: 998,  score: nan), (index: 997,  score: nan),

Op 77 must be bypassed to the CPU for the results to be correct!

  • mobilevitv2 ✅

Symptom: (index: 999, score: nan), (index: 998, score: nan), (index: 997, score: nan),

Op 603 must be bypassed to the CPU for the results to be correct! (The sketch after this list shows how to map these numeric op types back to their schema names.)

  • edgenext: large accuracy error; a simple bypass does not fix it, and accuracy drops by about 4 points on the test set ❓
  • efficientformerv2: even with op 77 on the CPU the accuracy loss cannot be recovered, and the accuracy is abysmal (0.4760, 0.6780, 0.7370) ❓❓
  • LeViT: even with op 77 on the CPU there is a huge accuracy error; validation-set accuracy is below 10% ... ...❓❓❓
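The notes above refer to problematic operators only by their numeric type (77, 603). A small sketch of my own (not part of the original debugging session) to map those numbers back to schema names, assuming the flatbuffers-generated helper `MNN::EnumNameOpType` in `schema/current/MNN_generated.h` is available, as it is in a stock MNN checkout:

```cpp
// Hypothetical helper: print the schema names of the numeric op types seen in
// the logs. EnumNameOpType returns an empty string for values outside the enum.
// build (paths are assumptions): g++ -I MNN/schema/current -I MNN/3rd_party/flatbuffers/include optype.cpp -o optype
#include <cstdio>
#include "MNN_generated.h"

int main() {
    const int types[] = {77, 603};
    for (int t : types) {
        printf("op type %d -> %s\n", t,
               MNN::EnumNameOpType(static_cast<MNN::OpType>(t)));
    }
    return 0;
}
```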

precision

reference

-D MNN_BUILD_QUANTOOLS=ON

Models marked with * in the table skip quantization of the first layer (or the first two layers) in order to improve accuracy.

The INT8 quantization implementation currently lives in a separate folder and is isolated behind its own macro. Users who want INT8 quantized inference need to manually enable the MNN_CUDA_QUANT macro at build time to compile the CUDA int8 sources.

With the newly added ARMv8.6 instruction support, GemmInt8 and GemmBF16 performance improves.

BF16 brings performance gains on low/mid-range phones and on the little cores of high-end devices, and it lowers memory usage. According to the MNN team's internal testing, BF16 gives a 5%-30% improvement over FP32 for various models on the low/mid-range cores (A53, A55, A53 Kryo, A55 Kryo) of different devices.

Currently, on devices that support ARMv8.2, fp16 and int8 with sdot acceleration already deliver better performance than fp32.

Overview: the ARMv8.6 ISA adds general matrix-multiply instructions and bf16 support, with a theoretical throughput of 2x ARMv8.2 sdot; implementing int8/bf16 matrix multiplication with these instructions brings a significant speedup. The referenced article implements MNN's ConvInt8 and MatMul operators with the new ARMv8.6 instructions and obtains up to roughly a 90% performance improvement.

The acceleration instructions involved are smmla and bfmmla: smmla is used to compute GEMM-int8 and bfmmla to compute GEMM-bf16. Compared with sdot, these two instructions perform twice the computation at the same latency, so the theoretical speedup over sdot is 100%.
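For intuition, here is a scalar reference of my own (a sketch for illustration, not MNN code) for what a single smmla instruction computes: both 128-bit sources are viewed as 2x8 int8 matrices, and the destination's 2x2 int32 matrix accumulates A·Bᵀ, i.e. 32 multiply-accumulates per instruction versus 16 for sdot.

```cpp
// Scalar reference for one AArch64 SMMLA instruction (illustrative sketch only).
#include <cstdint>
#include <cstdio>

void smmla_reference(int32_t acc[2][2], int8_t a[2][8], int8_t b[2][8]) {
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 8; ++k)
                acc[i][j] += int32_t(a[i][k]) * int32_t(b[j][k]);  // 32 int8 MACs total
}

int main() {
    int8_t a[2][8], b[2][8];
    int32_t acc[2][2] = {{0, 0}, {0, 0}};
    for (int i = 0; i < 2; ++i)
        for (int k = 0; k < 8; ++k) { a[i][k] = int8_t(k + 1); b[i][k] = int8_t(i + 1); }
    smmla_reference(acc, a, b);
    printf("%d %d / %d %d\n", acc[0][0], acc[0][1], acc[1][0], acc[1][1]); // 36 72 / 36 72
    return 0;
}
```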

For int8 quantized models, quantized operators are computed in int8 precision at inference time; if the device supports ARMv8.6, operators accelerated with the smmla instruction are used. For floating-point models, inference precision is controlled through the Precision option in BackendConfig: the default Precision_Normal runs in fp32, and the low-precision Precision_Low runs in fp16. To distinguish fp16 from bf16, a Precision_Low_BF16 option was added; when the user selects it, the bf16 backend is used, and if the device supports ARMv8.6, operators accelerated with the bfmmla instruction are used.
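A minimal sketch of how that precision selection looks from the C++ API (assumptions: the `Precision_Low_BF16` enum value exists in the MNN release used here, and `model.mnn` is a placeholder for a converted float model):

```cpp
// Select inference precision via MNN's BackendConfig (sketch, not from this repo).
#include <MNN/Interpreter.hpp>
#include <memory>

int main() {
    std::shared_ptr<MNN::Interpreter> net(
        MNN::Interpreter::createFromFile("model.mnn"));

    MNN::ScheduleConfig config;
    config.type = MNN_FORWARD_CPU;

    MNN::BackendConfig backendConfig;
    // Precision_Normal -> fp32, Precision_Low -> fp16 (ARMv8.2 fmla),
    // Precision_Low_BF16 -> bf16 (bfmmla when the CPU supports ARMv8.6).
    backendConfig.precision = MNN::BackendConfig::Precision_Low_BF16;
    config.backendConfig = &backendConfig;

    auto session = net->createSession(config);
    // ... fill the input tensor here ...
    net->runSession(session);
    return 0;
}
```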

armv8.2: ARM processors with this ISA started to appear in 2018; the main ARM CPUs include the A76, A77, A55, etc. This ISA introduces the dot instruction for int8 computation and the fmla instruction for fp16 computation. On earlier ARM processors, int8 quantized computation had to widen int8 to int16 before multiplying (to avoid int8*int8 overflow) and produced an int32 afterwards, so the narrow int8 width brought no performance advantage at all until the dot instruction appeared. The dot instruction performs 16 int8 multiply-accumulate operations per cycle, producing 4 int32 results, as illustrated in the referenced article.

This instruction fully exploits the narrow int8 width while still guaranteeing no overflow; single-instruction throughput is 4x higher than the fp32 case. This is why int8 quantization gives a large performance boost on processors from ARMv8.2 onwards.
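A scalar reference for one sdot (SDOT, vector form) instruction, again a sketch of my own rather than MNN code: each of the 4 int32 lanes accumulates a 4-way int8 dot product, so a single instruction performs 16 int8 multiply-accumulates with no overflow risk.

```cpp
// Scalar reference for one AArch64 SDOT (vector) instruction (illustrative sketch only).
#include <cstdint>
#include <cstdio>

void sdot_reference(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int lane = 0; lane < 4; ++lane)
        for (int k = 0; k < 4; ++k)
            acc[lane] += int32_t(a[lane * 4 + k]) * int32_t(b[lane * 4 + k]);  // 16 int8 MACs total
}

int main() {
    int8_t a[16], b[16];
    int32_t acc[4] = {0, 0, 0, 0};
    for (int i = 0; i < 16; ++i) { a[i] = int8_t(i); b[i] = 1; }
    sdot_reference(acc, a, b);
    printf("%d %d %d %d\n", acc[0], acc[1], acc[2], acc[3]); // 6 22 38 54
    return 0;
}
```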

result

CPU/GPU FP32/FP16 (`--fp16` storage)

| Model | Top-1 | Top-1 //20 est. | Top-1 //50 est. | #params | GMACs |
|---|---|---|---|---|---|
| efficientformerv2_s0 | - | 76.2 | 76.1 | 3.5M | 0.40G |
| efficientformerv2_s1 | - | 78.3 | 79.8 | 6.1M | 0.65G |
| efficientformerv2_s2 | - | 82.1 | 82.0 | 12.6M | 1.25G |
| SwiftFormer_XS | - | 76.2 | 75.1 | 3.5M | 0.4G |
| SwiftFormer_S | - | 78.4 | 78.2 | 6.1M | 1.0G |
| SwiftFormer_L1 | - | 80.6 | 81.8 | 12.1M | 1.6G |
| EMO_1M | - | 70.8 | 68.3 | 1.3M | 0.26G |
| EMO_2M | - | 74.8 | 73.7 | 2.3M | 0.44G |
| EMO_5M | - | 78.2 | 77.6 | 5.1M | 0.90G |
| EMO_6M | - | 79.1 | 78.1 | 6.1M | 0.96G |
| edgenext_xx_small | - | 70.8 | 70.7 | 1.3M | 0.26G |
| edgenext_x_small | - | 74.8 | 74.7 | 2.3M | 0.54G |
| edgenext_small/usi | - | 80.6 | 79.9 | 5.6M | 1.26G |
| *with fp16, the models above produce completely wrong results (NaN)* | | | | | |
| mobilevitv2_050 | - | 69.9 | 66.6 | 1.4M | 0.5G |
| mobilevitv2_075 | - | 75.1 | 74.3 | 2.9M | 1.0G |
| mobilevitv2_100 | - | 77.9 | 76.9 | 4.9M | 1.8G |
| mobilevitv2_125 | - | 79.2 | 80.7 | 7.5M | 2.8G |
| mobilevitv2_150 | - | 80.9 | 81.8 | 10.6M | 4.0G |
| mobilevitv2_175 | - | 80.7 | 81.0 | 14.3M | 5.5G |
| mobilevitv2_200 | - | 82.0 | 83.1 | 18.4M | 7.2G |
| mobilevit_xx_small | - | 68.9 | 66.5 | 1.3M | 0.36G |
| mobilevit_x_small | - | 74.0 | 73.7 | 2.3M | 0.89G |
| mobilevit_small | - | 77.6 | 77.9 | 5.6M | 2.0G |
| LeViT_128S | - | 75.9 | 76.1 | 7.8M | 0.30G |
| LeViT_128 | - | 79.4 | 78.1 | 9.2M | 0.41G |
| LeViT_192 | - | 79.6 | 79.6 | 11M | 0.66G |
| LeViT_256 | - | 81.1 | 81.4 | 19M | 1.12G |
| resnet50 | - | 79.6 | 81.3 | 25.6M | 4.1G |
| mobilenetv3_large_100 | - | 75.6 | 75.3 | 5.5M | 0.29G |
| tf_efficientnetv2_b0 | - | 78.2 | 76.7 | 7.1M | 0.72G |
| tf_efficientnetv2_b1 | - | 79.4 | 79.2 | 8.1M | 1.2G |
| tf_efficientnetv2_b2 | - | 81.7 | 80.4 | 10.1M | 1.7G |
| tf_efficientnetv2_b3 | - | 81.8 | 82.3 | 14.4M | 3.0G |
CPU/GPU FP32/FP16 (`--weightQuantBits 8 /[--weightQuantAsymmetric true]` storage)

| Model | Top-1 | Top-1 //20 est. | Top-1 //50 est. | #params | GMACs |
|---|---|---|---|---|---|
| efficientformerv2_s0 | - | 75.6/76.0 | 75.5/75.3 | 3.5M | 0.40G |
| efficientformerv2_s1 | - | 78.1/78.3 | 79.1/78.7 | 6.1M | 0.65G |
| efficientformerv2_s2 | - | 81.8/82.2 | 81.1/81.7 | 12.6M | 1.25G |
| SwiftFormer_XS | - | 75.8/76.1 | 74.6/75.2 | 3.5M | 0.4G |
| SwiftFormer_S | - | 78.0/78.3 | 77.8/78.3 | 6.1M | 1.0G |
| SwiftFormer_L1 | - | 80.5/80.5 | 82.0/81.5 | 12.1M | 1.6G |
| EMO_1M | - | 70.8/70.6 | 69.5/68.1 | 1.3M | 0.26G |
| EMO_2M | - | 74.4/74.8 | 73.5/73.3 | 2.3M | 0.44G |
| EMO_5M | - | 78.0/78.2 | 77.4/77.3 | 5.1M | 0.90G |
| EMO_6M | - | 79.3/79.0 | 77.9/78.2 | 6.1M | 0.96G |
| edgenext_xx_small | - | 70.7/70.9 | 70.8/71.2 | 1.3M | 0.26G |
| edgenext_x_small | - | 74.8/74.8 | 75.0/74.8 | 2.3M | 0.54G |
| edgenext_small/usi | - | 80.5/80.6 | 79.7/79.8 | 5.6M | 1.26G |
| *with fp16, the models above produce completely wrong results (NaN)* | | | | | |
| mobilevitv2_050 | - | 0 | 0 | 1.4M | 0.5G |
| mobilevitv2_075 | - | 0 | 0 | 2.9M | 1.0G |
| mobilevitv2_100 | - | 0 | 0 | 4.9M | 1.8G |
| mobilevitv2_125 | - | 0 | 0 | 7.5M | 2.8G |
| mobilevitv2_150 | - | 0 | 0 | 10.6M | 4.0G |
| mobilevitv2_175 | - | 0 | 0 | 14.3M | 5.5G |
| mobilevitv2_200 | - | 2 | 0 | 18.4M | 7.2G |
| mobilevit_xx_small | - | 68.5/67.9 | 66.9/66.2 | 1.3M | 0.36G |
| mobilevit_x_small | - | 74.0/74.0 | 72.6/72.9 | 2.3M | 0.89G |
| mobilevit_small | - | 77.8/77.4 | 77.4/78.1 | 5.6M | 2.0G |
| LeViT_128S | - | 76.3/75.7 | 76.1/75.8 | 7.8M | 0.30G |
| LeViT_128 | - | 78.9/79.0 | 77.8/77.3 | 9.2M | 0.41G |
| LeViT_192 | - | 79.8/79.6 | 79.9/79.4 | 11M | 0.66G |
| LeViT_256 | - | 81.0/81.3 | 81.7/81.0 | 19M | 1.12G |
| resnet50 | - | 79.4/79.2 | 80.3/81.1 | 25.6M | 4.1G |
| mobilenetv3_large_100 | - | 75.9/75.9 | 75.3/75.6 | 5.5M | 0.29G |
| tf_efficientnetv2_b0 | - | 78.2/78.6 | 77.0/77.2 | 7.1M | 0.72G |
| tf_efficientnetv2_b1 | - | 79.4/78.9 | 78.8/78.9 | 8.1M | 1.2G |
| tf_efficientnetv2_b2 | - | 81.5/81.7 | 80.5/80.7 | 10.1M | 1.7G |
| tf_efficientnetv2_b3 | - | 81.4/81.8 | 82.1/82.4 | 14.4M | 3.0G |

MISC

interesting grep log
$ grep GLES -rIn . --exclude-dir=project
./CMakeLists.txt:564:    list(APPEND MNN_EXTRA_DEPENDS GLESv3)
./schema/default/MNN.fbs:428:    OPENGLES,
./schema/current/MNN_generated.h:2548:  ForwardType_OPENGLES = 3,
./schema/current/MNN_generated.h:2559:    ForwardType_OPENGLES,
./schema/current/MNN_generated.h:2570:    "OPENGLES",
./schema/current/MNN_generated.h:7949:    "OPENGLES",
./source/backend/vulkan/vulkan/vulkan_core.h:6365:#define VK_EXT_DISCARD_RECTANGLES_SPEC_VERSION 1
./source/backend/vulkan/vulkan/vulkan_core.h:6366:#define VK_EXT_DISCARD_RECTANGLES_EXTENSION_NAME "VK_EXT_discard_rectangles"
./source/backend/opencl/core/runtime/OpenCLWrapper.cpp:33:        "libGLES_mali.so",
./source/backend/opencl/core/runtime/OpenCLWrapper.cpp:41:        "/system/vendor/lib64/egl/libGLES_mali.so",
./source/backend/opencl/core/runtime/OpenCLWrapper.cpp:42:        "/system/lib64/egl/libGLES_mali.so",
./source/backend/opencl/core/runtime/OpenCLWrapper.cpp:47:        "/system/vendor/lib/egl/libGLES_mali.so", "/system/lib/egl/libGLES_mali.so",
./source/backend/opengl/CMakeLists.txt:8:    target_link_libraries(MNN_GL MNN GLESv3 EGL)
./source/backend/opengl/GLHead.hpp:21:#include <GLES2/gl2.h>
./source/backend/opengl/GLHead.hpp:22:#include <GLES2/gl2ext.h>
./source/backend/opengl/GLHead.hpp:23:#include <GLES3/gl31.h>
./3rd_party/OpenCLHeaders/CL/cl_gl_ext.h:56: *  This allows us to avoid having to decide whether to include GL headers or GLES here.
./3rd_party/OpenCLHeaders/CL/cl_platform.h:380:/* Mirror types to GL types. Mirror types allow us to avoid deciding which 87s to load based on whether we are using GL or GLES here. */
Raspberry Pi 4B & Khadas VIM3 Linux Vulkan failure debugging journey

The story starts with the blog post https://qengineering.eu/install-vulkan-on-raspberry-pi.html, which confidently lists MNN Vulkan performance numbers, yet on my boards it simply would not run:

Error for /home/albert/work/mnn-orin/source/backend/vulkan/component/VulkanDevice.cpp, 81
mnn_perf-test: /home/albert/work/mnn-orin/source/backend/vulkan/component/VulkanDevice.cpp:81: MNN::VulkanDevice::VulkanDevice(std::shared_ptr<VulkanInstance>, const std::vector<const char *> &): Assertion `res' failed.
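Before diving into the patches below, a quick standalone check (my addition, not part of the original debugging session) shows what the driver actually reports. The assertion fires during MNN's device creation, which would be consistent with requesting a feature or extension the driver does not expose, though that is only an assumption to verify:

```cpp
// Print whether the first Vulkan device reports the feature MNN force-enables.
// build: g++ check_vk.cpp -lvulkan -o check_vk
#include <vulkan/vulkan.h>
#include <cstdio>

int main() {
    VkApplicationInfo app{};
    app.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
    app.apiVersion = VK_API_VERSION_1_0;

    VkInstanceCreateInfo ici{};
    ici.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
    ici.pApplicationInfo = &app;

    VkInstance instance;
    if (vkCreateInstance(&ici, nullptr, &instance) != VK_SUCCESS) {
        printf("vkCreateInstance failed\n");
        return 1;
    }

    uint32_t count = 0;
    vkEnumeratePhysicalDevices(instance, &count, nullptr);
    if (count == 0) { printf("no Vulkan device found\n"); return 1; }

    VkPhysicalDevice dev;
    count = 1;
    vkEnumeratePhysicalDevices(instance, &count, &dev);

    VkPhysicalDeviceProperties props;
    vkGetPhysicalDeviceProperties(dev, &props);
    VkPhysicalDeviceFeatures feat;
    vkGetPhysicalDeviceFeatures(dev, &feat);

    printf("device: %s\n", props.deviceName);
    printf("shaderStorageImageWriteWithoutFormat: %d\n",
           feat.shaderStorageImageWriteWithoutFormat);

    vkDestroyInstance(instance, nullptr);
    return 0;
}
```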

I tried three fixes, but they all ended in the same error:

  • patch 1
diff --git a/source/backend/vulkan/component/VulkanDevice.cpp b/source/backend/vulkan/component/VulkanDevice.cpp
index fc144af..8abb6ac 100644
--- a/source/backend/vulkan/component/VulkanDevice.cpp
+++ b/source/backend/vulkan/component/VulkanDevice.cpp
@@ -64,7 +64,7 @@ VulkanDevice::VulkanDevice(std::shared_ptr<VulkanInstance> instance, const std::
     VkPhysicalDeviceFeatures mDeviceFeature;
     ::memset(&mDeviceFeature, 0, sizeof(mDeviceFeature));
     mDeviceFeature.shaderStorageImageWriteWithoutFormat = VK_TRUE;
-    //vkGetPhysicalDeviceFeatures(mPhysicalDevice, &mDeviceFeature);
+    vkGetPhysicalDeviceFeatures(mPhysicalDevice, &mDeviceFeature);

     VkDeviceCreateInfo deviceCreateInfo{
         /* .sType                   = */ VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
$ ./mnn_perf-test --only-test v3 --backend=v
Creating MNN Interpreter: mobilenetv3_large_100
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 792,  score: 1.056641), (index: 620,  score: 1.003906), (index: 892,  score: 0.818359),
$ ./mnn_perf-test --only-test res --backend=v
Creating MNN Interpreter: resnet50
The device support i8sdot:0, support fp16:0, support i8mm: 0
(index: 999,  score: 0.000000), (index: 998,  score: 0.000000), (index: 997,  score: 0.000000),
  • patch 2
diff --git a/source/backend/vulkan/component/VulkanDevice.cpp b/source/backend/vulkan/component/VulkanDevice.cpp
index fc144af..2ba4b6e 100644
--- a/source/backend/vulkan/component/VulkanDevice.cpp
+++ b/source/backend/vulkan/component/VulkanDevice.cpp
@@ -8,6 +8,16 @@

 #include "backend/vulkan/component/VulkanDevice.hpp"
 #include <string.h>
+#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES_KHR (VkStructureType)1000177000
+typedef struct VkPhysicalDevice8BitStorageFeaturesKHR
+{
+    VkStructureType sType;
+    void* pNext;
+    VkBool32 storageBuffer8BitAccess;
+    VkBool32 uniformAndStorageBuffer8BitAccess;
+    VkBool32 storagePushConstant8;
+} VkPhysicalDevice8BitStorageFeaturesKHR;
+
 //#define MNN_VULKAN_PRINT_EXT
 namespace MNN {
 VulkanDevice::VulkanDevice(std::shared_ptr<VulkanInstance> instance, const std::vector<const char*>& device_extensions)
@@ -66,6 +76,46 @@ VulkanDevice::VulkanDevice(std::shared_ptr<VulkanInstance> instance, const std::
     mDeviceFeature.shaderStorageImageWriteWithoutFormat = VK_TRUE;
     //vkGetPhysicalDeviceFeatures(mPhysicalDevice, &mDeviceFeature);

+    std::vector<const char*> enabledExtensions = {
+        "VK_KHR_8bit_storage",
+        "VK_KHR_16bit_storage",
+        "VK_KHR_bind_memory2",
+        "VK_KHR_buffer_device_address",
+        "VK_KHR_create_renderpass2",
+        "VK_KHR_dedicated_allocation",
+        "VK_KHR_descriptor_update_template",
+        "VK_KHR_external_memory",
+        "VK_KHR_get_memory_requirements2",
+        "VK_KHR_maintenance1",
+        "VK_KHR_maintenance2",
+        "VK_KHR_maintenance3",
+        "VK_KHR_multiview",
+        "VK_KHR_shader_float_controls",
+        "VK_KHR_storage_buffer_storage_class",
+        "VK_KHR_swapchain",
+        "VK_EXT_memory_budget"
+    };
+
+    void* enabledExtensionFeatures = 0;
+    // enable int8 storage
+    VkPhysicalDevice8BitStorageFeaturesKHR enabled8BitStorageFeatures;
+    enabled8BitStorageFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES_KHR;
+    enabled8BitStorageFeatures.pNext = enabledExtensionFeatures;
+    enabled8BitStorageFeatures.storageBuffer8BitAccess = 1;
+    enabled8BitStorageFeatures.uniformAndStorageBuffer8BitAccess = VK_FALSE;
+    enabled8BitStorageFeatures.storagePushConstant8 = VK_FALSE;
+    enabledExtensionFeatures = &enabled8BitStorageFeatures;
+
+    // enable fp16/int16 storage
+    VkPhysicalDevice16BitStorageFeaturesKHR enabled16BitStorageFeatures;
+    enabled16BitStorageFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES_KHR;
+    enabled16BitStorageFeatures.pNext = enabledExtensionFeatures;
+    enabled16BitStorageFeatures.storageBuffer16BitAccess = 1;
+    enabled16BitStorageFeatures.uniformAndStorageBuffer16BitAccess = VK_FALSE;
+    enabled16BitStorageFeatures.storagePushConstant16 = VK_FALSE;
+    enabled16BitStorageFeatures.storageInputOutput16 = VK_FALSE;
+    enabledExtensionFeatures = &enabled16BitStorageFeatures;
+
     VkDeviceCreateInfo deviceCreateInfo{
         /* .sType                   = */ VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
         /* .pNext                   = */ nullptr,
@@ -78,6 +128,12 @@ VulkanDevice::VulkanDevice(std::shared_ptr<VulkanInstance> instance, const std::
         /* .ppEnabledExtensionNames = */ device_extensions.data(),
         /* .pEnabledFeatures        = */ &mDeviceFeature,
     };
+
+    deviceCreateInfo.pNext = enabledExtensionFeatures;
+    deviceCreateInfo.enabledExtensionCount = enabledExtensions.size();
+    deviceCreateInfo.ppEnabledExtensionNames = enabledExtensions.data();
+    deviceCreateInfo.pEnabledFeatures = 0; // VkPhysicalDeviceFeatures pointer
+
     CALL_VK(vkCreateDevice(mPhysicalDevice, &deviceCreateInfo, nullptr, &mDevice));
     vkGetPhysicalDeviceProperties(mPhysicalDevice, &mDeviceProty);
     vkGetPhysicalDeviceMemoryProperties(mPhysicalDevice, &mMemoryProty);

This was written by following ncnn's runtime trace, but the result is exactly the same error....

This was written by following ncnn's source code, but the result is exactly the same error....

The open-source panfrost Vulkan driver on the Khadas VIM3 fails even more badly...
