Blocking on Vectors

Blocking/tiling of ops on vector type, i.e. converting outer dims to loops at the vector-dialect level.

Proposal

Introduce support for register blocking on vector-dialect ops, or downstream ops operating on vector type, i.e. allow for converting outer dims to loops at the vector level.

Rationale

The crux is two-fold:

  1. Tiling on linalg is great ... if your pipeline goes through linalg

Not all paths to vector go through linalg (e.g. coming out of Triton or in non-tensor compilers that want vectorized code). As such, vector ops might not yet be wrapped up in the appropriate loop structure, e.g. in terms of register blocking. Such loop structure can be crucial for the performance of compiler-generated microkernels (e.g., see this repo).

  2. linalg-level decision making, e.g. regarding tiling, should be as agnostic as possible to hardware considerations

In order to separate concerns, we should allow the vector part of the pipeline to make decisions which account for, e.g., register sizes/instruction widths and the number of registers, so that it can control, e.g., register spilling and repeated prefetching. Instead of being forced to unroll outer dims, allowing the intermediary step of converting to loops also lets us control, e.g., code bloat at the vector level. Further, tiling at the linalg-on-tensor level can then be restricted to 1) making sure operands fit in the appropriate levels of the memory hierarchy and 2) ensuring (inner) dims are multiples of SIMD/SIMT widths.
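
As a rough, hypothetical data point: on an AVX-512 target with 32 vector registers of 16 f32 lanes each, an 8x32 f32 accumulator tile occupies 8 * 32 / 16 = 16 registers, i.e. half the register file, leaving the rest for A/B fragments; whether that is the right split is exactly the kind of hardware-specific decision we want to make at the vector level rather than bake into linalg-level tiling.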

Context

Upstream has previously declared that tiling on vector ops is inappropriate, paraphrased as "you should do that on linalg" [0] [1].

Limitations of status quo

Currently, when the linalg-level part of the pipeline finishes, right before "vectorization", the IR it has produced must be such that its ops (still operating on tensors) become vector ops that can only be unrolled (after vector-level linearization). This implies that this upper part of the pipeline must already know which tensor sizes (once unrolled as lower-D vectors) will fit in registers. That information might not be available at that point in the pipeline (e.g. we get handed some Triton IR that was not written with our hardware in mind), or we may want the upper part of our pipeline to stay agnostic to hardware specifics so that the lower, hardware-aware part of the pipeline can be wholly concerned with getting the final loops right (which might, e.g., differ across hardware variants/generations).


Motivating IR example

Starting with "vectorized" IR, e.g. where higher-level tiling has given rise to ops of about the following sizes,

  #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
  #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
  func.func @brgemm_blocking(%arg0: memref<128x256x512xf32>, %arg1: memref<128x512x256xf32>, %arg2: memref<256x256xf32>) {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<128x256x512xf32>, vector<128x256x512xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<128x512x256xf32>, vector<128x512x256xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<256x256xf32>, vector<256x256xf32>
    %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<128x256x512xf32>, vector<128x512x256xf32> into vector<256x256xf32>
    vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, memref<256x256xf32>
    return
  }

we want to end up with IR like

  func.func @brgemm_blocking(%arg0: memref<128x256x512xf32>, %arg1: memref<128x512x256xf32>, %arg2: memref<256x256xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c256 = arith.constant 256 : index
    %c8 = arith.constant 8 : index
    %c32 = arith.constant 32 : index
    %c128 = arith.constant 128 : index
    %c1 = arith.constant 1 : index
    %c512 = arith.constant 512 : index
    scf.for %arg3 = %c0 to %c256 step %c8 {
      scf.for %arg4 = %c0 to %c256 step %c32 {
        %subview = memref.subview %arg2[%arg3, %arg4] [8, 32] [1, 1] : memref<256x256xf32> to memref<8x32xf32, strided<[256, 1], offset: ?>>
        %0 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<8x32xf32, strided<[256, 1], offset: ?>>, vector<8x32xf32>
        %1 = scf.for %arg5 = %c0 to %c128 step %c1 iter_args(%arg6 = %0) -> (vector<8x32xf32>) {
          %2 = scf.for %arg7 = %c0 to %c512 step %c1 iter_args(%arg8 = %arg6) -> (vector<8x32xf32>) {
            %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7] [1, 8, 1] [1, 1, 1] : memref<128x256x512xf32> to memref<1x8x1xf32, strided<[131072, 512, 1], offset: ?>>
            %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4] [1, 1, 32] [1, 1, 1] : memref<128x512x256xf32> to memref<1x1x32xf32, strided<[131072, 256, 1], offset: ?>>
            %3 = vector.transfer_read %subview_0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x8x1xf32, strided<[131072, 512, 1], offset: ?>>, vector<1x8x1xf32>
            %4 = vector.transfer_read %subview_1[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x32xf32, strided<[131072, 256, 1], offset: ?>>, vector<1x1x32xf32>
            %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %3, %4, %arg8 : vector<1x8x1xf32>, vector<1x1x32xf32> into vector<8x32xf32>
            scf.yield %5 : vector<8x32xf32>
          }
          scf.yield %2 : vector<8x32xf32>
        }
        vector.transfer_write %1, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<8x32xf32>, memref<8x32xf32, strided<[256, 1], offset: ?>>
      }
    }
    return
  }

Potential solutions (and their pros and cons)

0. Make TilingInterface, or rather a split-off sub-interface, suitably abstract to be implementable by downstream ops, including those operating on vector types

Split the tiling portion of TilingInterface (in this case, tiling of outer loops) out into a separate interface, have the tileUsingSCF transform use just this interface, and make it properly agnostic as to which shaped types (including vector types) it is dealing with. With this in place, we could make downstream ops (on vector type) implement this interface and reuse tileUsingSCF on them.
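
As a rough sketch (the name is invented here and the exact split is up for discussion), such a sub-interface could carry just the TilingInterface methods that tileUsingSCF needs in order to materialize outer-dim loops, with their current signatures, so that TilingInterface could simply extend it:

  // Sketch of a possible split-off sub-interface; shown as plain C++ method
  // signatures for brevity (the real definition would live in TableGen).
  // The methods mirror the existing TilingInterface methods of the same names;
  // everything else (fusion hooks, etc.) would stay on TilingInterface.
  #include "mlir/Interfaces/TilingInterface.h"
  using namespace mlir;

  struct OuterDimTilingInterfaceSketch {
    // Parallel/reduction classification of each iteration-space dimension.
    SmallVector<utils::IteratorType> getLoopIteratorTypes();
    // The full iteration domain as (offset, size, stride) ranges.
    SmallVector<Range> getIterationDomain(OpBuilder &b);
    // Materialize the op computing the tile at the given offsets/sizes.
    FailureOr<TilingResult> getTiledImplementation(OpBuilder &b,
                                                   ArrayRef<OpFoldResult> offsets,
                                                   ArrayRef<OpFoldResult> sizes);
    // Where the tile of result `resultNumber` sits within the full result.
    LogicalResult getResultTilePosition(OpBuilder &b, unsigned resultNumber,
                                        ArrayRef<OpFoldResult> offsets,
                                        ArrayRef<OpFoldResult> sizes,
                                        SmallVector<OpFoldResult> &resultOffsets,
                                        SmallVector<OpFoldResult> &resultSizes);
  };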

Pros

  • Do not force downstreams to reinvent tiling/blocking abstractions
  • TilingInterface is currently wearing too many hats - a slimmed down interface for just tiling/blocking makes sense

Cons

  • Backdoor to having loop transformations on ops which operate on vector type?
    • This objection has not been properly substantiated; in any case, the suggestion here is not to implement this interface on vector ops upstream
  • ???

1. Implement TilingInterface on relevant vector ops, e.g. vector.contract

module {
  func.func @vector_matmul(%arg0: memref<64x128xf32>, %arg1: memref<128x64xf32>, %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf32>, vector<64x128xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x64xf32>, vector<128x64xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x64xf32>, vector<64x64xf32>
    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<64x128xf32>, vector<128x64xf32> into vector<64x64xf32>
    vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf32>, memref<64x64xf32>
    return %arg2 : memref<64x64xf32>
  }
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["vector.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tiled, %loops:3 = transform.structured.tile_using_for %matmuls tile_sizes [8,32,1] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,!transform.op<"scf.for">,!transform.op<"scf.for">)
    transform.yield
  }
}

yields

    func.func @vector_matmul(%arg0: memref<64x128xf32>, %arg1: memref<128x64xf32>, %arg2: memref<64x64xf32>) -> memref<64x64xf32> {
      %c32 = arith.constant 32 : index
      %c8 = arith.constant 8 : index
      %c128 = arith.constant 128 : index
      %c64 = arith.constant 64 : index
      %c1 = arith.constant 1 : index
      %c0 = arith.constant 0 : index
      %cst = arith.constant 0.000000e+00 : f32
      %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf32>, vector<64x128xf32>
      %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x64xf32>, vector<128x64xf32>
      %2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x64xf32>, vector<64x64xf32>
      %3 = scf.for %arg3 = %c0 to %c64 step %c8 iter_args(%arg4 = %2) -> (vector<64x64xf32>) {
        %4 = scf.for %arg5 = %c0 to %c64 step %c32 iter_args(%arg6 = %arg4) -> (vector<64x64xf32>) {
          %5 = scf.for %arg7 = %c0 to %c128 step %c1 iter_args(%arg8 = %arg6) -> (vector<64x64xf32>) {
            %6 = vector.extract_strided_slice %0 {offsets = [1, 1], sizes = [8, 1], strides = [1, 1]} : vector<64x128xf32> to vector<8x1xf32> // offsets -> [%arg3, %arg7]
            %7 = vector.extract_strided_slice %1 {offsets = [1, 1], sizes = [1, 32], strides = [1, 1]} : vector<128x64xf32> to vector<1x32xf32> // offsets -> [%arg5, %arg7]
            %8 = vector.extract_strided_slice %arg8 {offsets = [1, 1], sizes = [8, 32], strides = [1, 1]} : vector<64x64xf32> to vector<8x32xf32> // offsets -> [%arg3, %arg5]
            %9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %6, %7, %8 : vector<8x1xf32>, vector<1x32xf32> into vector<8x32xf32>
            %10 = vector.insert_strided_slice %9, %arg8 {offsets = [1, 1], strides = [1, 1]} : vector<8x32xf32> into vector<64x64xf32> // offsets -> [%arg3, %arg5]
            scf.yield %10 : vector<64x64xf32>
          }
          scf.yield %5 : vector<64x64xf32>
        }
        scf.yield %4 : vector<64x64xf32>
      }
      vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<64x64xf32>, memref<64x64xf32>
      return %arg2 : memref<64x64xf32>
    }

Note: the above IR is produced by a PoC implementation; the offsets are still incorrect because vector.extract_strided_slice and vector.insert_strided_slice do not yet allow offsets to be dynamic. From looking at the code for these ops and their associated lowerings, e.g. to vector.insert (which does support dynamic offsets), allowing dynamic offsets shouldn't be a problem.

From the above IR, we probably want to either

  1. lower vector.extract_strided_slice and vector.insert_strided_slice into sequences of vector.extract and vector.insert ops, or
  2. introduce a new pass that replaces each vector.extract_strided_slice and vector.insert_strided_slice with a vector.transfer_read/vector.transfer_write on a corresponding memref.subview of the buffer it ultimately reads from/writes to. This approach lets us directly obtain the desired IR (see the sketch after this list).
    • We would probably first want to hoist the K tile's vector.extract_strided_slice and vector.insert_strided_slice out of the K-loop while the loops still operate under pure value semantics.
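
A minimal sketch of what option 2's rewrite could look like on one B-operand slice of the PoC output above; %k and %n stand in for the K- and N-loop induction variables (illustrative names), and the subview type follows from the memref<128x64xf32> operand:

  // before: slice a tile out of the full-operand vector.transfer_read
  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x64xf32>, vector<128x64xf32>
  %7 = vector.extract_strided_slice %1 {offsets = [1, 1], sizes = [1, 32], strides = [1, 1]} : vector<128x64xf32> to vector<1x32xf32> // offsets -> [%k, %n]
  // after: read only the tile, through a subview at the (now dynamic) offsets
  %subview = memref.subview %arg1[%k, %n] [1, 32] [1, 1] : memref<128x64xf32> to memref<1x32xf32, strided<[64, 1], offset: ?>>
  %7 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<1x32xf32, strided<[64, 1], offset: ?>>, vector<1x32xf32>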

Pros

  • Embraces that vector.contract was developed as a vector-level structured op
  • Can reuse existing transforms with few changes
    • Enables moving transforms, e.g. transform.structured.tile_using_for, up and down the pipeline with ease
  • Can use vector's value semantics to deal with optimizations such as hoisting

Cons

  • TilingInterface is intended for more than we would support (e.g. we would not support fusion)
  • Gives an indirect path to the desired IR, e.g. we need to go through vector.extract_strided_slice and vector.insert_strided_slice and then through a new pass (see above)

2. New RegisterBlockingInterface on relevant vector ops and a new corresponding transform

TODO: list necessary RegisterBlockingInterface's methods and contrast with the TilingInterface's methods

Pros

  • Dedicated interface which is able to restrict itself to "tiling" outer dimensions
    • depending on overlap with TilingInterface's methods, we could make TilingInterface extend this interface (equivalently: extract just the "tile outer dims" methods of TilingInterface to a separate interface that TilingInterface extends)
  • Compared to Solution 1, the interface/transforms/driver can be specialized for retrieving hardware-relevant parameters

Cons

  • Likely to significantly mirror (and duplicate code from) TilingInterface and transforms like tileUsingSCF

3. New "reroll" transform (to lift unrolled vector instruction sequences to loops)

Basically, lifting repeated code sections back into loops.

Pros

  • One transform that would be completely non-invasive w.r.t. the upstream codebase
  • ???

Cons

  • Against the entire idea of gradual lowering - we would be reconstructing (loop) structure that we had only just discarded
  • TODO: ...

Details regarding TilingInterface on vector.contract Proof-of-Concept

Solution 2 - External model for TilingInterface on vector ops, registered and attached by the downstream mlir-opt equivalent

Use the addExtension and attachInterface functionality to "late bind" a TilingInterface implementation for the vector ops that we care about. This has the following advantages over Solution 1:

  1. Do not reimplement tiling logic
  2. Much less code to implement
  3. No custom transforms, just enabling of upstream transforms
    • A corollary is that tileUsingSCF/tileUsingFor is more likely to "commute" as we move it up and down the pipeline
  4. Clear path to upstreaming code

Here's an upstream example, for tensor.pad:

https://github.com/llvm/llvm-project/blob/9d487050a144b895950a6fd48b993513a714e69d/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp#L308-L311

The only place this function is called is

https://github.com/llvm/llvm-project/blob/fb29f19fdb0b2b3c8c87cc767482d941818e92a8/mlir/include/mlir/InitAllDialects.h#L190

Concretely, we would implement the "external model", i.e. a VectorContractTiling struct, for vector.contract in tpp-mlir and attach and register it right after the registerAllDialects(...) call in tpp-opt.
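
A minimal sketch of that wiring, assuming the external model is a VectorContractTiling struct analogous to the tensor.pad one linked above; the interface methods themselves (the actual tiling logic the PoC provides) are elided, so this only illustrates the registration side:

  // Schematic only: the TilingInterface methods that VectorContractTiling must
  // provide (getLoopIteratorTypes, getIterationDomain, getTiledImplementation,
  // getResultTilePosition, ...) are elided here.
  #include "mlir/Dialect/Vector/IR/VectorOps.h"
  #include "mlir/IR/DialectRegistry.h"
  #include "mlir/Interfaces/TilingInterface.h"

  namespace {
  struct VectorContractTiling
      : public mlir::TilingInterface::ExternalModel<VectorContractTiling,
                                                    mlir::vector::ContractionOp> {
    // ... method implementations analogous to the tensor.pad external model ...
  };
  } // namespace

  // Called from tpp-opt right after registerAllDialects(registry).
  void registerVectorContractTilingInterface(mlir::DialectRegistry &registry) {
    registry.addExtension(+[](mlir::MLIRContext *ctx,
                              mlir::vector::VectorDialect *dialect) {
      mlir::vector::ContractionOp::attachInterface<VectorContractTiling>(*ctx);
    });
  }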

Potential hitch: "making promises" - Not an issue in prototype

Upon dialect initialization, the Dialect::initialize() function "makes promises" regarding which ops implement which interfaces. Here is TensorDialect promising that tensor::PadOp has the TilingInterface: https://github.com/llvm/llvm-project/blob/15c2d1b328433d2c26327e072059c8960469d378/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp#L66

If this is crucial, then the above scheme is unlikely to work (as-is, at least). The relevant method, declarePromisedInterfaces, appears to be public and is normally invoked during registerAllDialects(...). Maybe we can have the relevant dialect (here, the vector dialect) "make the promise" after the register call. Maybe.

Update: this does not seem necessary. Calls do reach my downstream external-model implementation of TilingInterface attached to vector.contract.

Hitch #2: changes to tileUsingSCF

There is at least one stage in tileUsingSCF which presumes it is operating on tensors. We would need to change the following:

https://github.com/llvm/llvm-project/blob/1c6cecdbdd2470292ce0b508922d807e3100f85c/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L1048

https://github.com/llvm/llvm-project/blob/1c6cecdbdd2470292ce0b508922d807e3100f85c/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L575

Solution: just pull the following up, in front of the getOrCreateDestinations call in createInitialTensorsForTiling: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L84-L87

The following snag is also solved by making vector.contract implement DestinationStyleOpInterface: https://github.com/llvm/llvm-project/blob/ba9bd22e1b535a1669e3918fa77f6edaf6851d9a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L382

The main blocker is that SCF's tiling driver assumes that when it operates on value-producing ops, the produced values are tensors. As such, it only inserts tensor.insert_slice ops, while we need it to insert vector.insert_strided_slice ops: https://github.com/llvm/llvm-project/blob/da6d5fa79a558b66c281bed3f5ce848a69a65208/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp#L448

Hitch #3: vector's subview/extract_slice equivalent does not support dynamic offsets

vector.extract_strided_slice requires offsets (as well as sizes and strides) to be provided as attributes (i.e. statically at compile time). TilingInterface, however, assumes that tiling works by advancing the offsets dynamically along the tiled dimension(s). See here: https://github.com/llvm/llvm-project/blob/b2aba39001f6909965c4a9af47969e83717601c0/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td#L1239
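
For contrast, a small illustrative sketch (the SSA names are made up): the first form is what vector.extract_strided_slice accepts today, while plain vector.extract already takes dynamic positions, which suggests allowing dynamic offsets here is feasible:

  // today: offsets/sizes/strides must be static attributes
  %tile = vector.extract_strided_slice %acc {offsets = [8, 32], sizes = [8, 32], strides = [1, 1]} : vector<64x64xf32> to vector<8x32xf32>
  // by contrast, vector.extract accepts dynamic positions (%i, %j are index SSA values)
  %elem = vector.extract %acc[%i, %j] : f32 from vector<64x64xf32>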

Upstreaming considerations

This approach would allow us to demonstrate a use case/lowering path that does not go through linalg but still needs the tiling, whilst using the upstream transforms, and the code that we end up with is essentially the code that would go upstream (were our approach to be accepted later on).
