Blocking on Vectors
Introduce support for register blocking on `vector`-dialect ops, or downstream ops operating on `vector` types, i.e. allow for converting outer dims to loops at the `vector` level.
The crux is two-fold:
- Tiling on `linalg` is great ... if your pipeline goes through `linalg`.
  Not all paths to `vector` go through `linalg` (e.g. coming out of Triton, or in non-tensor compilers that want vectorized code). As such, vector ops might not yet be wrapped in the appropriate loop structure, e.g. in terms of register blocking. Such loop structure can be crucial for the performance of compiler-generated microkernels (e.g., see this repo).
- `linalg`-level decision making, e.g. regarding tiling, should be as agnostic as possible to hardware considerations.
  In order to separate concerns, we should allow the `vector` part of the pipeline to make decisions which account for, e.g., register sizes/instruction widths and the number of registers, so that it can control, e.g., register spilling and repeated prefetching. Instead of being forced to unroll outer dims, allowing the intermediary step of converting to loops also lets us exercise control over, e.g., code bloat at the `vector` level. Further, tiling at the `linalg`-on-`tensor` level can then be restricted to account for 1) making sure operands fit in appropriate levels of the memory hierarchy and 2) making sure (inner) dims are a multiple of SIMD/SIMT widths.
Upstream has previously declared that tiling on vector ops is inappropriate, paraphrased as "you should do that on linalg" [0] [1].
Currently, when the `linalg`-level part of the pipeline finishes, right before "vectorization", the IR it has produced must be such that ops (operating on tensors) can only be unrolled (after `vector`-level linearization). This implies that this upper part of the pipeline must already know which sizes of tensors (unrolled as lower-D vectors) will fit in registers. This might not be known in that part of the pipeline (e.g. we get handed some Triton IR that has not been written with our hardware in mind), or we may want the upper part of our pipeline to be agnostic to hardware specifics so that the lower, hardware-version-aware part of the pipeline can be wholly concerned with getting the final loops right (which might, e.g., differ across hardware variants/generations).
Starting with "vectorized" IR, e.g. where higher-level tiling has given rise to ops of about the following sizes,
```mlir
// Assumed indexing maps for the batch-reduce matmul (not shown in the original
// snippet): d0 = batch, d1 = m, d2 = n, d3 = k.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

func.func @brgemm_blocking(%arg0: memref<128x256x512xf32>, %arg1: memref<128x512x256xf32>, %arg2: memref<256x256xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<128x256x512xf32>, vector<128x256x512xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<128x512x256xf32>, vector<128x512x256xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<256x256xf32>, vector<256x256xf32>
  %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<128x256x512xf32>, vector<128x512x256xf32> into vector<256x256xf32>
  vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, memref<256x256xf32>
  return
}
```
we want to end up with IR like
```mlir
// Same assumed indexing maps as above: d0 = batch, d1 = m, d2 = n, d3 = k.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

func.func @brgemm_blocking(%arg0: memref<128x256x512xf32>, %arg1: memref<128x512x256xf32>, %arg2: memref<256x256xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c256 = arith.constant 256 : index
  %c8 = arith.constant 8 : index
  %c32 = arith.constant 32 : index
  %c128 = arith.constant 128 : index
  %c1 = arith.constant 1 : index
  %c512 = arith.constant 512 : index
  scf.for %arg3 = %c0 to %c256 step %c8 {
    scf.for %arg4 = %c0 to %c256 step %c32 {
      %subview = memref.subview %arg2[%arg3, %arg4] [8, 32] [1, 1] : memref<256x256xf32> to memref<8x32xf32, strided<[256, 1], offset: ?>>
      %0 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x32xf32, strided<[256, 1], offset: ?>>, vector<8x32xf32>
      %1 = scf.for %arg5 = %c0 to %c128 step %c1 iter_args(%arg6 = %0) -> (vector<8x32xf32>) {
        %2 = scf.for %arg7 = %c0 to %c512 step %c1 iter_args(%arg8 = %arg6) -> (vector<8x32xf32>) {
          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 8, 1] [1, 1, 1] : memref<128x256x512xf32> to memref<1x8x1xf32, strided<[131072, 512, 1], offset: ?>>
          %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 32] [1, 1, 1] : memref<128x512x256xf32> to memref<1x1x32xf32, strided<[131072, 256, 1], offset: ?>>
          %3 = vector.transfer_read %subview_0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x8x1xf32, strided<[131072, 512, 1], offset: ?>>, vector<1x8x1xf32>
          %4 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x32xf32, strided<[131072, 256, 1], offset: ?>>, vector<1x1x32xf32>
          %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %arg8 : vector<1x8x1xf32>, vector<1x1x32xf32> into vector<8x32xf32>
          scf.yield %5 : vector<8x32xf32>
        }
        scf.yield %2 : vector<8x32xf32>
      }
      vector.transfer_write %1, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<8x32xf32>, memref<8x32xf32, strided<[256, 1], offset: ?>>
    }
  }
  return
}
```
0. Make `TilingInterface`, or rather a split-off sub-interface, suitably abstract to be implementable by downstream ops, including those operating on vector types

Split out the tiling portion (in this case, of outer loops) of `TilingInterface` into a separate interface, make the `tileUsingSCF` transform use just this interface, and make it properly agnostic as to which shaped types (so including vector types) it is dealing with. With this in place we could make downstream ops (on vector types) implement this interface and reuse `tileUsingSCF` on them.

Pros
- Do not force downstreams to reinvent tiling/blocking abstractions
- `TilingInterface` is currently wearing too many hats - a slimmed-down interface for just tiling/blocking makes sense

Cons
- Backdoor to having loop transformations on ops which operate on vector type?
  - This objection has not been properly substantiated; in any case, the suggestion here is not to implement this interface on vector ops upstream
- ???
For example, given the following IR and transform script:

```mlir
module {
  func.func @vector_matmul(%arg0: memref<64x128xf32>, %arg1: memref<128x64xf32>, %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf32>, vector<64x128xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x64xf32>, vector<128x64xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x64xf32>, vector<64x64xf32>
    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<64x128xf32>, vector<128x64xf32> into vector<64x64xf32>
    vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf32>, memref<64x64xf32>
    return %arg2 : memref<64x64xf32>
  }
}
```
```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tiled, %loops:3 = transform.structured.tile_using_for %matmuls tile_sizes [8, 32, 1] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">, !transform.op<"scf.for">)
    transform.yield
  }
}
```
yields
```mlir
func.func @vector_matmul(%arg0: memref<64x128xf32>, %arg1: memref<128x64xf32>, %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
  %c32 = arith.constant 32 : index
  %c8 = arith.constant 8 : index
  %c128 = arith.constant 128 : index
  %c64 = arith.constant 64 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf32>, vector<64x128xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x64xf32>, vector<128x64xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x64xf32>, vector<64x64xf32>
  %3 = scf.for %arg3 = %c0 to %c64 step %c8 iter_args(%arg4 = %2) -> (vector<64x64xf32>) {
    %4 = scf.for %arg5 = %c0 to %c64 step %c32 iter_args(%arg6 = %arg4) -> (vector<64x64xf32>) {
      %5 = scf.for %arg7 = %c0 to %c128 step %c1 iter_args(%arg8 = %arg6) -> (vector<64x64xf32>) {
        %6 = vector.extract_strided_slice %0 {offsets = [1, 1], sizes = [8, 1], strides = [1, 1]} : vector<64x128xf32> to vector<8x1xf32> // offsets -> [%arg3, %arg7]
        %7 = vector.extract_strided_slice %1 {offsets = [1, 1], sizes = [1, 32], strides = [1, 1]} : vector<128x64xf32> to vector<1x32xf32> // offsets -> [%arg5, %arg7]
        %8 = vector.extract_strided_slice %arg8 {offsets = [1, 1], sizes = [8, 32], strides = [1, 1]} : vector<64x64xf32> to vector<8x32xf32> // offsets -> [%arg3, %arg5]
        %9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<8x1xf32>, vector<1x32xf32> into vector<8x32xf32>
        %10 = vector.insert_strided_slice %9, %arg8 {offsets = [1, 1], strides = [1, 1]} : vector<8x32xf32> into vector<64x64xf32> // offsets -> [%arg3, %arg5]
        scf.yield %10 : vector<64x64xf32>
      }
      scf.yield %5 : vector<64x64xf32>
    }
    scf.yield %4 : vector<64x64xf32>
  }
  vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf32>, memref<64x64xf32>
  return %arg2 : memref<64x64xf32>
}
```
Note the above IR is produced by a PoC implementation: the offsets are still incorrect as `vector.extract_strided_slice` and `vector.insert_strided_slice` do not yet allow `offsets` to be dynamic. From looking at the code for them and the associated lowerings, e.g. to `vector.insert` (which does support dynamic offsets), allowing `offsets` to be dynamic shouldn't be a problem.
From the above IR, we probably want to either
- lower `vector.extract_strided_slice` and `vector.insert_strided_slice` to multiple `vector.extract` and `vector.insert` ops, etc., or
- introduce a new pass to replace the `vector.extract_strided_slice`s and `vector.insert_strided_slice`s with the `vector.transfer_read`s and `vector.transfer_write`s they extract from/insert to, along with corresponding `memref.subview`s. This approach allows us to directly obtain the desired IR.
  - We would probably first want to hoist the K tile's `vector.extract_strided_slice` and `vector.insert_strided_slice` out of the K-loop while the loops are still dealing with pure value semantics.

Pros
- Embraces that `vector.contract` was developed as a `vector`-level structured op
- Can reuse existing transforms with few changes
  - Enables moving transforms, e.g. `transform.structured.tile_using_for`, up and down the pipeline with ease
- Can use `vector`'s value semantics to deal with optimizations such as hoisting

Cons
- `TilingInterface` is intended for more things than we'd support (e.g. no fusion)
- Gives an indirect path to the desired IR, e.g. we need to go through `vector.extract_strided_slice` and `vector.insert_strided_slice` and then through a new pass (see above)
TODO: list the necessary `RegisterBlockingInterface` methods and contrast them with `TilingInterface`'s methods
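
As a starting point for that comparison, here is a purely hypothetical sketch (the name `RegisterBlockingOpInterfaceSketch` and the exact method split are assumptions, not a settled design) of the method set such an interface might expose. In a real implementation the interface would be declared in ODS; this is just a plain C++ illustration of keeping the "tile outer dims" subset of `TilingInterface` and dropping the fusion/scalarization-oriented methods.

```cpp
// Hypothetical sketch only: a trimmed-down blocking interface, shown as plain
// C++ method declarations for illustration. Names are placeholders.
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/TilingInterface.h" // Range, TilingResult, IteratorType

namespace mlir {

// Subset of TilingInterface that is enough to turn outer dims into scf loops.
class RegisterBlockingOpInterfaceSketch {
public:
  virtual ~RegisterBlockingOpInterfaceSketch() = default;

  // Same roles as the corresponding TilingInterface methods:
  virtual SmallVector<utils::IteratorType> getLoopIteratorTypes() = 0;
  virtual SmallVector<Range> getIterationDomain(OpBuilder &b) = 0;
  virtual FailureOr<TilingResult>
  getTiledImplementation(OpBuilder &b, ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) = 0;

  // Deliberately absent (the fusion/scalarization-oriented TilingInterface
  // methods): getResultTilePosition, generateResultTileValue,
  // generateScalarImplementation.
};

} // namespace mlir
```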

Pros
- Dedicated interface which is able to restrict itself to "tiling" outer dimensions
  - Depending on the overlap with `TilingInterface`'s methods, we could make `TilingInterface` extend this interface (equivalently: extract just the "tile outer dims" methods of `TilingInterface` into a separate interface that `TilingInterface` extends)
- Vs. Solution 1, the interface/transforms/driver can be specialized for retrieving hardware-relevant parameters

Cons
- Likely to significantly mirror (and duplicate code from) `TilingInterface` and transforms like `tileUsingSCF`
Basically, lifting repeated code sections back into loops.
Pros
- One transform that would be completely non-invasive w.r.t. the upstream codebase
- ???
Cons
- Against the entire idea of gradual lowering - we would be reconstructing (loop) structure that we had only just beforehand
- TODO: ...
Solution 2 - External Model for TilingInterface on vector ops, registered and attached by a downstream `mlir-opt`-equivalent

Use the `addExtension` and `attachInterface` functionality to "late bind" a `TilingInterface` implementation for the `vector` ops that we care about. This has the following advantages over Solution 1:
- Do not reimplement tiling logic
  - much less code to implement
- No custom transforms, just enabling of upstream transforms
  - corollary is that `tileUsingSCF`/`tileUsingFor` is more likely to "commute" as we move it up and down the pipeline
- clear path to upstreaming code
Here's an upstream example, for `tensor.pad`:
The only place this function is called is
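
For orientation, the upstream `tensor.pad` external model ships with a registration hook that a tool opts into; presumably the example being referred to boils down to something like the following (a sketch of the usage pattern, not a quote of the linked code):

```cpp
// Enabling the upstream tensor TilingInterface external models (which include
// the tensor.pad model) from a tool's dialect-registry setup; registerAllDialects
// does the equivalent for mlir-opt.
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

void enableTensorTilingModels(mlir::DialectRegistry &registry) {
  mlir::tensor::registerTilingInterfaceExternalModels(registry);
}
```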
Concretely, we would implement the "external model", i.e. a `VectorContractTiling` struct, for `vector.contract` in `tpp-mlir` and do its attaching and registration right after the `registerAllDialects(...)` call in `tpp-opt`.
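
A minimal sketch of what that attach-and-register step could look like; the method bodies are elided and `registerVectorContractTiling` is a made-up helper name, not existing tpp-mlir code:

```cpp
// Sketch: attach a downstream TilingInterface external model to vector.contract.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/TilingInterface.h"

namespace {
struct VectorContractTiling
    : public mlir::TilingInterface::ExternalModel<VectorContractTiling,
                                                  mlir::vector::ContractionOp> {
  // Implementations of getIterationDomain, getLoopIteratorTypes,
  // getTiledImplementation, etc. would go here (elided in this sketch).
};
} // namespace

// Called right after registerAllDialects(registry) in the tpp-opt main.
void registerVectorContractTiling(mlir::DialectRegistry &registry) {
  registry.addExtension(+[](mlir::MLIRContext *ctx,
                            mlir::vector::VectorDialect *dialect) {
    mlir::vector::ContractionOp::attachInterface<VectorContractTiling>(*ctx);
  });
}
```

The extension callback only runs once the vector dialect is actually loaded into the context, so attaching after the `registerAllDialects(...)` call requires no changes upstream.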
Upon dialect initialization, the `Dialect::initialize()` function "makes promises" regarding which ops implement which interfaces. Here is `TensorDialect` promising that `tensor::PadOp` has the `TilingInterface`: https://github.com/llvm/llvm-project/blob/15c2d1b328433d2c26327e072059c8960469d378/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp#L66
If this is crucial, then the above scheme is unlikely to work (as-is, at least). The relevant method, `declarePromisedInterfaces`, appears to be public and is normally invoked during `registerAllDialects(...)`. Maybe we can make `TensorDialect` "make the promise" after the register call. Maybe.
Update: this does not seem necessary. I can get calls into my downstream external model implementation of `TilingInterface` attached downstream to `vector.contract`.
There is at least one stage in `tileUsingSCF` which presumes it is operating on tensors. We would need to change the following:
Solution: just pull the following up and in front of the `getOrCreateDestinations` call in `createInitialTensorsForTiling`: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L84-L87
The following snag is also solved by making `vector.contract` implement `DestinationStyleOpInterface`: https://github.com/llvm/llvm-project/blob/ba9bd22e1b535a1669e3918fa77f6edaf6851d9a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L382
The main blocker is that SCF's tiling driver assumes that when it operates on value-producing ops, the values produced are tensors. As such it only inserts `tensor.insert_slice` ops, while we need it to insert `vector.insert_strided_slice` ops: https://github.com/llvm/llvm-project/blob/da6d5fa79a558b66c281bed3f5ce848a69a65208/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L448
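
A hypothetical sketch of the kind of change meant here (the helper name and parameters are made up for illustration; this is not the upstream driver code): where the driver currently always creates `tensor.insert_slice`, it could branch on the destination's type. The vector path takes static offsets/strides because `vector.insert_strided_slice` currently only accepts them as attributes, which is exactly the next snag discussed below.

```cpp
// Sketch: yield the computed tile back into the destination, dispatching on the
// destination type instead of hard-coding tensor.insert_slice.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value insertTileIntoDest(OpBuilder &b, Location loc, Value tile,
                                Value dest, ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides,
                                ArrayRef<int64_t> staticOffsets,
                                ArrayRef<int64_t> staticStrides) {
  // Tensor destinations: what the driver does today.
  if (isa<RankedTensorType>(dest.getType()))
    return b.create<tensor::InsertSliceOp>(loc, tile, dest, offsets, sizes,
                                           strides);
  // Vector destinations: what we would need; offsets must be static for now.
  if (isa<VectorType>(dest.getType()))
    return b.create<vector::InsertStridedSliceOp>(loc, tile, dest,
                                                  staticOffsets, staticStrides);
  return Value(); // unsupported destination type in this sketch
}
```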
`vector.extract_strided_slice` assumes offsets (and sizes and strides) can be provided as attributes (i.e. statically at compile time). The `TilingInterface` assumes that tiling works through increasing the offsets dynamically along the specified dimension(s). See here:
https://github.com/llvm/llvm-project/blob/b2aba39001f6909965c4a9af47969e83717601c0/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td#L1239
This approach would allow us to demonstrate a use case/lowering path which needs the tiling but does not go through `linalg` - whilst using the upstream transforms - and the code that we end up with is essentially the code that would go upstream (were our approach to be accepted later on).