matmul first approach - Joejiong/buddy-mlir GitHub Wiki
Add 2 passes:
- transform the pointwise (1x1) conv into a matmul
- vectorization
- RVV (RISC-V Vector extension) support
./bin/buddy-opt ../examples/ConvOpt/pointwise_conv2d_nhwc_filter_hwcf.mlir -pointwise-conv-vectorization --print-ir-before-all
// -----// IR Dump Before {anonymous}::PointwiseConvVectorizationPass //----- //
module {
func @conv_2d_1x1(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> {
%0 = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
%1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x4x5x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x4x5x7xf32>) -> tensor<1x4x5x7xf32>
return %1 : tensor<1x4x5x7xf32>
}
}
// IR after the pass: the 1x1 conv is rewritten as collapse_shape + linalg.matmul + expand_shape.
module {
func @conv_2d_1x1(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> {
%0 = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
%1 = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
%2 = tensor.collapse_shape %arg1 [[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
%3 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
%4 = linalg.matmul ins(%1, %2 : tensor<20x2xf32>, tensor<2x7xf32>) outs(%3 : tensor<20x7xf32>) -> tensor<20x7xf32>
%5 = tensor.expand_shape %4 [[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
return %5 : tensor<1x4x5x7xf32>
}
}
Performance results (benchmark library built as DEBUG, so timings are indicative only):
> ./bin/pointwise-conv-2d-nhwc-hwcf-benchmark pointwise_benchmark [39d7e0c] modified
2021-12-20T18:42:07+08:00
Running ./bin/pointwise-conv-2d-nhwc-hwcf-benchmark
Run on (52 X 2500 MHz CPU s)
CPU Caches:
L1 Data 32 KiB (x26)
L1 Instruction 32 KiB (x26)
L2 Unified 1024 KiB (x26)
L3 Unified 36608 KiB (x1)
Load Average: 8.09, 11.43, 12.63
***WARNING*** Library was built as DEBUG. Timings may be affected.
-------------------------------------------------------------------------------------
Benchmark Time CPU Iterations
-------------------------------------------------------------------------------------
BM_PointwiseConv2DNhwcHwcf/100 0.124 ms 0.124 ms 5655
BM_PointwiseConv2DNhwcHwcfReturn/100 0.104 ms 0.104 ms 6728
BM_PointwiseConv2DNhwcHwcfReturnOrigin/100 0.150 ms 0.150 ms 4665
inputMemRef: [ 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 ] of shape: [ 1 4 5 2 ]
filterMemRef: [ 3 3 3 3 3 3 3 3 3 3 3 3 3 3 ] of shape: [ 1 1 2 7 ]
outputMemRef: [ 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 ] of shape: [ 1 4 5 7 ]
inputMemReturn: [ 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 ] of shape: [ 1 4 5 2 ]
filterMemReturn: [ 3 3 3 3 3 3 3 3 3 3 3 3 3 3 ] of shape: [ 1 1 2 7 ]
outputMemReturn: [ 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 ] of shape: [ 1 4 5 7 ]
inputMemReturn: [ 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 ] of shape: [ 1 4 5 2 ]
filterMemReturn: [ 3 3 3 3 3 3 3 3 3 3 3 3 3 3 ] of shape: [ 1 1 2 7 ]
outputMemReturnOrigin: [ 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 ] of shape: [ 1 4 5 7 ]