transfer_read with subview - Joejiong/buddy-mlir GitHub Wiki

# Lower transfer_read_subview.mlir with mlir-opt (vector -> scf -> std -> llvm)
# and pipe the result into mlir-cpu-runner, which runs the `entry` function.
# Nothing may follow a trailing backslash on a line, otherwise the line
# continuation is broken.
/Workspace/buddy-mlir/llvm/build/bin/mlir-opt transfer_read_subview.mlir \
  -convert-vector-to-scf \
  -lower-affine -convert-scf-to-std \
  -convert-vector-to-llvm \
  -convert-memref-to-llvm \
  -convert-std-to-llvm \
  -reconcile-unrealized-casts | /Workspace/buddy-mlir/llvm/build/bin/mlir-cpu-runner -e entry -entry-point-result=void \
  -shared-libs=/Workspace/buddy-mlir/llvm/build/lib/libmlir_c_runner_utils.so.14git

// 5x6 f32 test matrix. The element at index [i, j] holds 10 * i + j, so each
// printed value identifies the position it was read from.
memref.global "private" @gv : memref<5x6xf32> = dense<[
  [ 0.0,  1.0,  2.0,  3.0,  4.0,  5.0],
  [10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
  [20.0, 21.0, 22.0, 23.0, 24.0, 25.0],
  [30.0, 31.0, 32.0, 33.0, 34.0, 35.0],
  [40.0, 41.0, 42.0, 43.0, 44.0, 45.0]
]>

// Strided 2-D layout: linear index = d0 * s1 + s0 + d1, i.e. a symbolic
// offset s0, a symbolic stride s1 on the first dimension and a fixed stride
// of 1 on the last dimension. Used as the layout of the subview result type
// in @transfer_read_1d_unit_stride.
#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// Fixed layout with stride 6 on the first dimension and stride 2 on the last.
// NOTE(review): not referenced anywhere in this file.
#map1 = affine_map<(d0, d1) -> (6 * d0 + 2 * d1)>

// Non-contiguous, strided 1-D load from a 2-D memref.
//
// Starting at %A[%base1, %base2], reads a vector<9xf32>. The permutation map
// (d0, d1) -> (d0) ties the single vector dimension to the first memref
// dimension, so successive lanes step along d0 while the second index stays
// at %base2. -42.0 is passed as the padding operand of the transfer_read.
// The resulting vector is printed.
func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
  %pad = arith.constant -42.0 : f32
  %lanes = vector.transfer_read %A[%base1, %base2], %pad
      {permutation_map = affine_map<(d0, d1) -> (d0)>}
      : memref<?x?xf32>, vector<9xf32>
  vector.print %lanes : vector<9xf32>
  return
}

// Vector load with unit stride only on last dim.
//
// For every offset pair (%arg2, %arg3), with %arg2 in {1, 3} (1 to 5 step 2)
// and %arg3 in {0, 3} (0 to 6 step 3), takes a 1x2 subview of %A at
// [%arg2, %arg3] with strides [1, 1], reads its two elements as a
// vector<2xf32> and prints the vector. Four vectors are printed in total.
// The read is marked in_bounds, and -42.0 is passed as the padding operand.
func @transfer_read_1d_unit_stride(%A : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c5 = arith.constant 5 : index
  %c6 = arith.constant 6 : index
  %fm42 = arith.constant -42.0 : f32
  // Outer loop walks the first subview offset, inner loop the second.
  scf.for %arg2 = %c1 to %c5 step %c2 {
    scf.for %arg3 = %c0 to %c6 step %c3 {
      // The result type carries the strided layout #map0 (symbolic offset
      // and first-dim stride, unit stride on the last dim).
      %0 = memref.subview %A[%arg2, %arg3] [1, 2] [1, 1]
          : memref<?x?xf32> to memref<1x2xf32, #map0>
      // Read both elements of the 1x2 view starting at its origin.
      %1 = vector.transfer_read %0[%c0, %c0], %fm42 {in_bounds=[true]}
          : memref<1x2xf32, #map0>, vector<2xf32>
      vector.print %1 : vector<2xf32>
    }
  }
  return
}

// Test driver: casts the static 5x6 global @gv to a dynamically shaped
// memref and runs the read tests on it. Takes no arguments and returns
// nothing; all output is produced by vector.print in the callees.
func @entry() {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = memref.get_global @gv : memref<5x6xf32>
  // Erase the static shape so the callees see memref<?x?xf32>.
  %A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>

  // 1.a. Read 1D vector from 2D memref along the first dim, starting at
  // [1, 2]. Lanes past the last row receive the -42 padding value.
  call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
  // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )

  // 2.a. Read 1D vector from 2D memref with non-unit stride on first dim.
  // NOTE(review): this call is currently disabled. The CHECK lines below
  // describe the output it would print, so they cannot match while the call
  // stays commented out.
  // call @transfer_read_1d_unit_stride(%A) : (memref<?x?xf32>) -> ()
  // CHECK: ( 10, 11 )
  // CHECK: ( 13, 14 )
  // CHECK: ( 30, 31 )
  // CHECK: ( 33, 34 )

  return
}
⚠️ **GitHub.com Fallback** ⚠️