Kernel Design - minhsqtruong/FastGeMM GitHub Wiki

Kernel Design

Theoretical Peak Calculation

According to the linked reference, a GPU of compute capability 7.x has a 32-bit multiply-accumulate throughput of 64 floating-point operations per cycle per SM. With 36 SMs on the GPU operating at 1.62 GHz, the theoretical matrix-multiply throughput is 36 * 64 * 1.62 GHz = 3.732 TFLOPS (assuming the kernel is compute bound and counting each multiply-accumulate as one operation). Our fastgemm implementation tries to get close to this number.
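The peak figure is a product of three hardware numbers, so it is easy to recompute for a different GPU. The helper below is a minimal sketch; the function name `peak_tera_fma` and its parameter names are made up for illustration and are not part of fastgemm.

```cpp
// Peak FP32 multiply-accumulate rate, in tera-operations per second.
// Each multiply-accumulate is counted as ONE operation here (an assumption
// chosen to match the convention of the figure quoted in the text).
constexpr double peak_tera_fma(int sm_count,
                               int fma_per_cycle_per_sm,
                               double clock_ghz) {
    // SMs * (operations per cycle per SM) * (cycles per second) / 1e12
    return sm_count * fma_per_cycle_per_sm * clock_ghz * 1e9 / 1e12;
}
```

For example, `peak_tera_fma(36, 64, 1.62)` evaluates to about 3.732, which is the figure quoted above. If each multiply-accumulate is counted as two floating-point operations (one multiply, one add), the result doubles.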

Kernel Parameters

Given a matrix multiply operation C = A * B, where A is m x k, B is k x n, and C is m x n, we select (m,n,k) = (24,512,k) for each SM. To justify this choice, view the product as a sum of outer products: for each k, an outer product is computed between a, the k-th column of A (m elements), and b, the k-th row of B (n elements). Each k iteration is therefore a partial update of C, so C is reused across all iterations.

We store C in shared memory because it is close to the ALUs and has a large capacity. With 49152 bytes of shared memory, each SM can hold 49152/4 = 12288 single-precision elements of C. This gives m*n = 12288, and one possible solution is (m,n) = (24,512).

In the outer product, each element of the a vector is reused against every element of the b vector. We therefore keep the a vector in each thread's private registers, which is why m tends to be small. We choose m = 24, so that n = 12288/24 = 512. With 4 warp schedulers and 32 threads per warp, each thread works on 512/(4*32) = 4 elements of the b vector in each k iteration. 4 is a convenient number because, if the data is packed as float4, each k iteration requires only one load of b per thread.

To choose k, it is best to plot kernel throughput over different k values.
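The arithmetic behind (m,n) = (24,512) and the 4 elements of b per thread can be written out as a small host-side helper. This is a sketch only: `TileConfig` and `derive_tile` are hypothetical names, and the divisibility checks are assumptions added so that a poor choice of m fails early.

```cpp
#include <cassert>
#include <cstddef>

// Tile shape handled by one SM, plus the per-thread share of the b vector.
struct TileConfig {
    int m;                   // rows of the C tile (length of the a vector)
    int n;                   // columns of the C tile (length of the b vector)
    int b_elems_per_thread;  // b elements each thread handles per k iteration
};

inline TileConfig derive_tile(std::size_t smem_bytes, int m,
                              int warps, int threads_per_warp) {
    // Number of single-precision C elements that fit in shared memory.
    const std::size_t c_elems = smem_bytes / sizeof(float);
    // m must divide the C capacity evenly so that m * n == c_elems.
    assert(c_elems % static_cast<std::size_t>(m) == 0);
    const int n = static_cast<int>(c_elems / static_cast<std::size_t>(m));
    // The b vector is split evenly across all threads of the block.
    const int threads = warps * threads_per_warp;
    assert(n % threads == 0);
    return TileConfig{m, n, n / threads};
}
```

Calling `derive_tile(49152, 24, 4, 32)` reproduces the numbers in the text: n = 512 and 4 elements of b per thread, which is exactly one float4.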

Preloading a elements into registers

Since each k iteration requires a different vector of A, a new a vector must be loaded in every iteration. With m = 24, each thread has to load 6 float4 values (24 floats) into its private registers. This latency can be hidden by preloading the next a vector into registers while the current one is being used. Holding both the current and the next a vector takes 2 * 24 = 48 registers in total, which is still well under the limit of 255 registers per thread.
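The preloading scheme can be illustrated with a host-side C++ emulation of the k loop. This is not the CUDA kernel: the two 24-element arrays stand in for the 48 registers, and the function name and the memory layouts (A column-major, B and C row-major) are assumptions made for the example. On a GPU the preload would overlap with the multiply-accumulates, whereas here the steps simply run one after another.

```cpp
#include <array>
#include <cstddef>
#include <vector>

constexpr int M = 24;  // tile height: length of each a vector

// C (M x n, row-major) += A (M x k, column-major) * B (k x n, row-major),
// computed as k rank-1 (outer-product) updates with a double-buffered a vector.
void gemm_outer_product_prefetch(const std::vector<float>& A,
                                 const std::vector<float>& B,
                                 std::vector<float>& C,
                                 std::size_t n, std::size_t k) {
    if (k == 0) return;
    std::array<float, M> a_cur{};   // a vector used in this iteration
    std::array<float, M> a_next{};  // a vector preloaded for the next one

    // Prologue: load column 0 of A before the loop starts.
    for (int i = 0; i < M; ++i) a_cur[i] = A[i];

    for (std::size_t kk = 0; kk < k; ++kk) {
        // Preload the next column of A (skipped on the last iteration).
        if (kk + 1 < k) {
            for (int i = 0; i < M; ++i) a_next[i] = A[(kk + 1) * M + i];
        }
        // Rank-1 update: C += a_cur * (row kk of B).
        const float* b = &B[kk * n];
        for (int i = 0; i < M; ++i) {
            for (std::size_t j = 0; j < n; ++j) {
                C[i * n + j] += a_cur[i] * b[j];
            }
        }
        // The preloaded vector becomes the current one.
        a_cur = a_next;
    }
}
```

The caller is expected to size A as M * k, B as k * n, and C as M * n elements, with C holding the initial values to accumulate into (zeros for a plain product).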