SASS Study from OpenAI Kernels - chhwang/devel-note GitHub Wiki
MaxAs Overview & Workflow
What is SASS?
OpenAI kernels are written in SASS (CUDA assembly language). SASS is largely undocumented, and its specification differs by GPU architecture. The only official document we could find on SASS is a brief instruction set reference for each GPU architecture: https://docs.nvidia.com/cuda/cuda-binary-utilities/index.html#instruction-set-ref. PTX is one level above SASS. Unlike SASS, PTX is forward-compatible and well-documented.
What is MaxAs?
MaxAs is an assembler that converts Maxwell/Pascal SASS into a CUBIN file. It is the result of SASS hacking by Scott Gray from OpenAI. MaxAs wiki: https://github.com/NervanaSystems/maxas/wiki/Introduction
By writing SASS with MaxAs, we get more precise control over GPU resources than by writing PTX code.
MaxAs workflow (these steps are automated in gemm/Makefile; read this only if you want the details)
MaxAs first reads a CUBIN template file and inserts SASS code into the template.
How do we write a CUBIN template?
Option 1: write a CUDA template and convert it into CUBIN. For instance, the CUDA template code for a GEMM kernel may look like this:
<gemm.cu File Start>
extern "C" __device__ void hgemm_32x32x32_NN_vec(int blk_a, int blk_b, int cda, int cdb, int cdc,
int m, int n, int k, void *A, void *B, void *C,
float alpha, float beta)
{
extern __shared__ float share[];
}
<gemm.cu File End>
We need to compile this code as relocatable device code, which means that the kernel must later be linked with other executable code. To do this, pass "-rdc true" and a proper gencode option to NVCC. Example for compute capability 6.0:
$ nvcc -rdc true -cubin -gencode arch=compute_60,code=sm_60 -o gemm.cubin gemm.cu
This produces the template CUBIN file "gemm.cubin".
Option 2: write a PTX template and convert it into CUBIN.
This is not much different from Option 1; we chose Option 2. Our PTX template generator for GEMM kernels is gemm/ptx/gen_ptx.py (a Python script).
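To make "PTX template" concrete, here is a minimal sketch of what such a generator could produce: an empty `.func` device function with the same parameter list as the gemm.cu template above, plus an `extern` shared-memory declaration. This is hypothetical and not the contents of gen_ptx.py; the PTX version (`5.0`), the target (`sm_60`), and the `.f32` type for alpha/beta are assumptions.

```python
# Hypothetical sketch of a PTX template generator; the real
# gemm/ptx/gen_ptx.py may emit different directives.

# Argument list of the GEMM device function (same order as gemm.cu above).
HGEMM_PARAMS = [
    ("blk_a", ".b32"), ("blk_b", ".b32"),
    ("cda", ".b32"), ("cdb", ".b32"), ("cdc", ".b32"),
    ("m", ".b32"), ("n", ".b32"), ("k", ".b32"),
    ("A", ".b64"), ("B", ".b64"), ("C", ".b64"),   # pointers are 64-bit
    ("alpha", ".f32"), ("beta", ".f32"),           # assumed float scalars
]


def gen_ptx_template(name, params, arch="sm_60", ptx_version="5.0"):
    """Return PTX text for an empty device function (.func) named `name`."""
    param_lines = ",\n".join(
        "    .param {} {}".format(ptype, pname) for pname, ptype in params
    )
    return "\n".join([
        ".version {}".format(ptx_version),
        ".target {}".format(arch),
        ".address_size 64",
        "",
        "// counterpart of `extern __shared__ float share[];` in the CUDA template",
        ".extern .shared .align 4 .b8 share[];",
        "",
        ".visible .func {}(".format(name),
        param_lines,
        ")",
        "{",
        "    ret;",  # empty body: MaxAs inserts the real SASS later
        "}",
        "",
    ])


if __name__ == "__main__":
    print(gen_ptx_template("hgemm_32x32x32_NN_vec", HGEMM_PARAMS))
```

The output would then go through the same `-rdc`/`-cubin` compilation as Option 1 to produce the template CUBIN.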
How do we write the SASS code?
Refer to our GEMM code in gemm/sass/xgemm_32x32x32.sass.
The MaxAs wiki provides details on SASS code: https://github.com/NervanaSystems/maxas/wiki/Introduction.
Now insert the SASS code into the CUBIN template:
$ cd gemm/maxas/
$ ./maxas.pl -i -w -k hgemm_32x32x32_NN_vec -Dtype h -DNN 1 -Dvec 1 ../sass/xgemm_32x32x32.sass gemm.cubin
Run ./maxas.pl -h to see the available options.
Now "gemm.cubin" contains the complete definition of the function. You can link it at runtime with CUDA JIT linking (refer to "app_gemm_test/test_gemm.cu"), or simply convert it into an object file:
$ nvcc -dc -gencode arch=compute_60,code=sm_60 -o gemm.o gemm.cubin
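The three steps above can be scripted; the sketch below only assembles the exact command lines shown in this section and, by default, prints them instead of running them. It is a hypothetical helper, not our gemm/Makefile. It assumes all commands are run from one directory that holds gemm.cubin (gemm/maxas/ in the steps above), so the paths may need adjusting.

```python
import shlex
import subprocess


def build_commands(kernel, sass_file, defines, arch="60",
                   src="gemm.cu", cubin="gemm.cubin", obj="gemm.o"):
    """Return argv lists for the three workflow steps described above."""
    gencode = "arch=compute_{0},code=sm_{0}".format(arch)
    # 1) compile the template as relocatable device code into a CUBIN
    template = ["nvcc", "-rdc", "true", "-cubin", "-gencode", gencode,
                "-o", cubin, src]
    # 2) insert the SASS into the template with MaxAs (-D<name> <value> pairs)
    insert = ["./maxas.pl", "-i", "-w", "-k", kernel]
    for name, value in defines:
        insert += ["-D" + name, str(value)]
    insert += [sass_file, cubin]
    # 3) turn the finished CUBIN into a linkable object file
    link = ["nvcc", "-dc", "-gencode", gencode, "-o", obj, cubin]
    return [template, insert, link]


def run(commands, cwd=None, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for argv in commands:
        print(" ".join(shlex.quote(a) for a in argv))
        if not dry_run:
            subprocess.run(argv, cwd=cwd, check=True)


if __name__ == "__main__":
    cmds = build_commands("hgemm_32x32x32_NN_vec", "../sass/xgemm_32x32x32.sass",
                          [("type", "h"), ("NN", 1), ("vec", 1)])
    run(cmds)  # dry run: only prints the commands
```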
What have we changed in the MaxAs implementation?
The original MaxAs assembler cannot read relocatable device code. As a result, if we create a CUBIN template of a device kernel and try to insert SASS into it, MaxAs returns an error. We fixed this to support device kernels.
How did we fix it?
The original MaxAs assumes that the kernel in the template is a global kernel and tries to find both the code section and the parameters section in the CUBIN template. However, relocatable device code (device kernels) does not contain a parameters section, so MaxAs returns an error. This is expected: device kernels receive their parameters at runtime (via registers and local memory), while global kernels receive them via constant memory, whose layout is fixed at compile time.
Our change simply skips the parameters section if it does not exist.
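The logic of the fix can be sketched as follows. MaxAs itself is written in Perl, so this Python function only mirrors the idea and is not the actual patch; the section names `.text.<kernel>` (code) and `.nv.info.<kernel>` (parameter metadata) are assumptions made for illustration.

```python
def find_kernel_sections(section_names, kernel):
    """Locate a kernel's code section and its (optional) parameter section.

    Assumed naming: code in ".text.<kernel>", parameter metadata in
    ".nv.info.<kernel>"; MaxAs may match sections differently.
    """
    code = ".text." + kernel
    params = ".nv.info." + kernel
    if code not in section_names:
        # a missing code section is still a hard error
        raise KeyError("code section {!r} not found".format(code))
    # The fix: relocatable device code may have no parameter section,
    # so report None instead of failing.
    return code, (params if params in section_names else None)


if __name__ == "__main__":
    print(find_kernel_sections([".text.gemm", ".nv.info.gemm"], "gemm"))
    print(find_kernel_sections([".text.gemm"], "gemm"))
```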
How to Write Device Kernels from Global Kernels in SASS
Why do we write device kernels?
The original OpenAI SASS code defines global GEMM kernels. Since we want to call GEMM from inside our own global kernel, we need to change it to define device GEMM kernels.
Before changing the code, first calculate how many registers are available.
While global kernels read function parameters from constant memory, device kernels receive parameters via registers and local memory. This means that, even with identical logic, a device kernel may need more registers than the equivalent global kernel.
A single SM (on Maxwell and Pascal) has 65,536 registers in total. For instance, if a kernel uses 128 registers per thread with a thread block size of 128, an SM can run 65,536 / (128 × 128) = 4 blocks of this kernel concurrently. In this case, using even one extra register may degrade the kernel throughput by reducing the number of blocks per SM to ⌊65,536 / (128 × 129)⌋ = 3.
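The arithmetic above can be checked with a small sketch. `granular=False` reproduces the simple formula in the text. `granular=True` additionally rounds each warp's registers up to an allocation unit of 256, which is our assumption about the Maxwell/Pascal allocation granularity; the CUDA Occupancy Calculator remains the authoritative tool. Shared-memory, warp-count, and block-count limits are ignored.

```python
REGS_PER_SM = 65536   # 64K 32-bit registers per SM (Maxwell/Pascal)
WARP_SIZE = 32
ALLOC_UNIT = 256      # assumed per-warp register allocation granularity


def ceil_to(x, unit):
    """Round x up to a multiple of unit."""
    return -(-x // unit) * unit


def blocks_per_sm(regs_per_thread, threads_per_block, granular=False):
    """Register-limited number of blocks that fit on one SM."""
    if granular:
        warps = -(-threads_per_block // WARP_SIZE)   # ceil division
        regs_per_block = warps * ceil_to(regs_per_thread * WARP_SIZE, ALLOC_UNIT)
    else:
        regs_per_block = regs_per_thread * threads_per_block
    return REGS_PER_SM // regs_per_block


if __name__ == "__main__":
    for regs in (128, 129, 159, 255):
        print(regs, blocks_per_sm(regs, 128), blocks_per_sm(regs, 128, granular=True))
```

With 128-thread blocks (assumed), both variants give 4 blocks for 128 registers, 3 for 129 or 159 registers, and 2 for 255 registers.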
Calculate how many extra registers you can use before registers become the bottleneck. The CUDA Occupancy Calculator gives more accurate numbers: http://developer.download.nvidia.com/compute/cuda/CUDA_Occupancy_calculator.xls
Note that fewer blocks per SM does not always mean a severe drop in throughput. For xgemm_32x32x32.sass, 2 blocks per SM and 3 blocks per SM performed similarly in my experimental settings. This result depends on the length of the inner dimension (k), the kernel implementation, and the computing power of the GPU.
What changed in xgemm_32x32x32.sass?
OpenAI GEMM has 6 different SASS implementations, and we currently use only xgemm_32x32x32.sass. The others do not work yet, but we may support them in the near future.
The details of the changes follow.
Backgrounds
NOTE: the explanations below, except for those about MaxAs, are my own findings and do not come from an official document. They may change in future architectures or CUDA versions. I figured them out by compiling test CUDA code and reading the disassembled code with nvdisasm or MaxAs (using the ./maxas.pl -e option).
SASS control codes. (Read this only if you want to understand SASS code further.)
Each instruction in SASS code is preceded by a control code. For instance:
E:D:C:B:A MOV R1, R2;
Here, E:D:C:B:A is the control code. Each field has the following meaning.
A is the stall count: the number of clocks to wait before issuing the next instruction. Valid values are 0 ~ 15 (one hex digit, 0 ~ f, in MaxAs notation). Normally A is 1; sometimes it is 0, or 2 or larger. If it is 0, this instruction is issued together with the following instruction. E.g.
--:-:-:-:0 MOV R1, R2;
--:-:-:-:1 IADD R3, R3, R4;
issues both instructions concurrently. Of course, if IADD tried to read R1, it would read the old value, not the value copied from R2. This feature is supported from the Maxwell architecture, and only when the two instructions do not use the same computational resources: e.g., issuing MOV and IADD at the same time is valid, but issuing two MOVs at the same time is invalid. Issuing LDS and STS at the same time is also invalid. If A is larger than 1, the next instruction is issued only after more clocks, e.g. because the given instruction takes longer to complete.
B is either a dash (-) or Y. A dash means no hint; Y is a yield hint, which indicates that the warp scheduler may switch to another warp at this instruction. I don't know much more about this field.
C sets a "write dependency" flag. Valid values are 1 ~ 6, each of which indicates one of six independent flag resources. E.g.
--:-:1:-:1 LDG R1, [R2];
...
01:-:-:-:1 IADD R3, R3, R1;
LDG reads global memory at the address in R2 and stores the value in R1. A global memory read takes a long time, so IADD needs to wait for write flag 1, by setting E = 01, to ensure that it reads R1 only after the global memory read finishes. Each flag resource has a corresponding E value: 1 --> 01, 2 --> 02, 3 --> 04, 4 --> 08, 5 --> 10, 6 --> 20. These are hexadecimal bit-field values, so they can be added to wait on several flags at once. E.g.
--:-:3:-:1 LDG R1, [R2];
--:-:4:-:1 LDG R3, [R5];
...
0c:-:-:-:1 IADD R3, R3, R1;
IADD waits for both LDGs (0c = 04 + 08). Note the gap (...) between the LDGs and IADD. This gap is necessary because setting a flag resource takes an extra clock. If no instructions go in between, the last LDG needs to stall for 2 clocks to ensure that the flag is set before IADD is issued. E.g.
--:-:3:-:1 LDG R1, [R2];
--:-:4:-:2 LDG R3, [R5];
0c:-:-:-:1 IADD R3, R3, R1;
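The flag-to-E mapping above is a plain bit field: flag n sets bit n-1. The small helpers below compute the E field for a set of flags and decode it again. They are illustrative helpers of ours, not part of MaxAs.

```python
def wait_mask(flags):
    """E field (two hex digits) that waits on the given flag numbers 1..6."""
    mask = 0
    for flag in flags:
        if not 1 <= flag <= 6:
            raise ValueError("flag must be in 1..6, got {}".format(flag))
        mask |= 1 << (flag - 1)   # flag n -> bit n-1: 1->01, 2->02, 3->04, ...
    return "{:02x}".format(mask) if mask else "--"


def mask_flags(field):
    """Inverse of wait_mask: list the flag numbers encoded in an E field."""
    if field == "--":
        return []
    mask = int(field, 16)
    return [bit + 1 for bit in range(6) if mask >> bit & 1]


if __name__ == "__main__":
    print(wait_mask([1]), wait_mask([3, 4]), mask_flags("0c"))
```

ORing the bits gives the same result as adding them when the flags are distinct, and it also tolerates a flag that is listed twice.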
D sets a "read dependency" flag. It is very similar to C and shares the flag resources 1 ~ 6 with C. A read dependency is used when a following instruction needs to modify a register that is still being read by a preceding instruction. E.g.
--:5:4:-:2 LDG R3, [R5];
18:-:-:-:1 IADD R5, R3, R5;
LDG reads R5 and IADD modifies it, so we use flag 5 to wait until the read finishes. LDG also writes R3, which IADD reads, so we use flag 4 to wait until the write finishes. IADD therefore waits for both flags: E = 18 = 10 + 08.
E, the wait mask, is already explained above together with C and D.
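Putting the five fields together, a MaxAs-style line can be split into its control fields and the instruction, and a consumer can be checked against the flags its producer sets. This is a hypothetical helper for reading SASS listings, not a MaxAs tool; parsing A as one hex digit is an assumption about the notation.

```python
from typing import List, NamedTuple, Optional, Tuple


class Control(NamedTuple):
    wait: List[int]        # E: flags to wait on before issuing
    read: Optional[int]    # D: read-dependency flag set by this instruction
    write: Optional[int]   # C: write-dependency flag set by this instruction
    yield_hint: bool       # B: 'Y' or '-'
    stall: int             # A: stall count


def parse_line(line: str) -> Tuple[Control, str]:
    """Split 'E:D:C:B:A INSTR;' into its control fields and the instruction."""
    code, instr = line.strip().split(None, 1)
    e, d, c, b, a = code.split(":")
    wait = [] if e == "--" else [bit + 1 for bit in range(6) if int(e, 16) >> bit & 1]
    return Control(
        wait=wait,
        read=None if d == "-" else int(d),
        write=None if c == "-" else int(c),
        yield_hint=(b == "Y"),
        stall=int(a, 16),  # assumed: one hex digit
    ), instr


def waits_for(consumer: Control, producer: Control) -> bool:
    """Conservative check: does the consumer wait on every flag the producer sets?"""
    needed = {f for f in (producer.read, producer.write) if f is not None}
    return needed <= set(consumer.wait)


if __name__ == "__main__":
    ldg, _ = parse_line("--:5:4:-:2 LDG R3, [R5];")
    iadd, instr = parse_line("18:-:-:-:1 IADD R5, R3, R5;")
    print(iadd.wait, waits_for(iadd, ldg))
```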
If you want more details, refer to the MaxAs wiki: https://github.com/NervanaSystems/maxas/wiki/Introduction
Global Kernel Argument Passing. Global kernel arguments are stored in constant memory. Any value in constant memory can be used directly as an instruction operand, just like a register value. We can access the first 4 bytes of the first argument via c[0x0][0x140], the following 4 bytes via c[0x0][0x144], and so on.
E.g., read the first 8 bytes of the arguments into R16 and R17:
--:-:-:-:1 MOV R16, c[0x0][0x140];
--:-:-:-:1 MOV R17, c[0x0][0x144];
Note that R16 and R17 form a 64-bit-aligned register pair, so they can be used together as a single 64-bit operand, e.g.
--:-:-:-:1 STG.64 [R3], R16;
stores the R16 and R17 values to global memory at the address in R3.
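Given the argument list, the c[0x0][...] operand for each argument can be computed from the 0x140 base. The sketch below does this for the original argument order of our GEMM (listed later on this page). It assumes each argument is aligned to its own size (8-byte pointers on 8-byte boundaries); the helper names are ours.

```python
PARAM_BASE = 0x140   # first argument lives at c[0x0][0x140]


def const_operands(args):
    """Map (name, size_in_bytes) arguments to their c[0x0][...] operands.

    Assumption: each argument is aligned to its own size (natural alignment).
    """
    table = {}
    offset = 0
    for name, size in args:
        offset = -(-offset // size) * size          # round up to alignment
        table[name] = ["c[0x0][{:#x}]".format(PARAM_BASE + offset + i)
                       for i in range(0, size, 4)]  # one operand per 4-byte word
        offset += size
    return table


ORIGINAL_ORDER = [("C", 8), ("A", 8), ("B", 8), ("alpha", 4), ("beta", 4),
                  ("cda", 4), ("cdb", 4), ("cdc", 4), ("m", 4), ("n", 4),
                  ("k", 4), ("blk_a", 4), ("blk_b", 4)]

if __name__ == "__main__":
    ops = const_operands(ORIGINAL_ORDER)
    # MOV the two halves of C into R16/R17, as in the example above
    print("--:-:-:-:1 MOV R16, {};".format(ops["C"][0]))
    print("--:-:-:-:1 MOV R17, {};".format(ops["C"][1]))
```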
Device Kernel Argument Passing. Device kernel arguments are passed in registers or local memory (which physically resides in global memory). Right before the global kernel calls a device kernel, it stores the device kernel arguments, in order, in 12 registers (R4 ~ R15). If the total size of the arguments is larger than 12 x 4 bytes = 48 bytes, the remainder is passed via local memory: the global kernel stores those values in local memory and sets R1 to the address of the first 4 bytes of the remaining arguments.
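The rule above can be sketched as a function that assigns each 4-byte word of the argument list to R4 ~ R15 first and then to local memory at [R1], [R1+0x4], and so on. The sketch assumes words are packed in order without padding, and it refuses (rather than guesses) when an 8-byte pointer would straddle R15 and local memory; neither argument order used on this page hits that case. The function and list names are ours.

```python
NUM_ARG_REGS = 12     # R4 ~ R15
FIRST_ARG_REG = 4


def place_device_args(args):
    """Return where each 4-byte word of each (name, size) argument is passed."""
    placement = {}
    word = 0
    for name, size in args:
        nwords = size // 4
        if word < NUM_ARG_REGS < word + nwords:
            raise ValueError(name + " would straddle registers and local memory")
        locs = []
        for _ in range(nwords):
            if word < NUM_ARG_REGS:
                locs.append("R{}".format(FIRST_ARG_REG + word))
            else:
                off = (word - NUM_ARG_REGS) * 4
                locs.append("[R1]" if off == 0 else "[R1+{:#x}]".format(off))
            word += 1
        placement[name] = locs
    return placement


ORIGINAL_ORDER = [("C", 8), ("A", 8), ("B", 8), ("alpha", 4), ("beta", 4),
                  ("cda", 4), ("cdb", 4), ("cdc", 4), ("m", 4), ("n", 4),
                  ("k", 4), ("blk_a", 4), ("blk_b", 4)]
CHANGED_ORDER = [("blk_a", 4), ("blk_b", 4), ("cda", 4), ("cdb", 4), ("cdc", 4),
                 ("m", 4), ("n", 4), ("k", 4), ("A", 8), ("B", 8), ("C", 8),
                 ("alpha", 4), ("beta", 4)]

if __name__ == "__main__":
    for name, locs in place_device_args(CHANGED_ORDER).items():
        print(name, locs)
```

For the original order, this reproduces the register/local-memory table in "Changed order of argument passing" below; for the changed order, C, alpha, and beta land at [R1] ~ [R1+0xc], matching the LDL instructions there.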
Context Retrieval of Device Kernels. Global kernels return with the EXIT instruction, while device kernels return with the RET instruction. Before executing RET, a device kernel needs to ensure that the following registers hold the same values as when the device kernel was called:
R1, R16 ~ R31, R36 ~ R39, R44 ~ R47, R52 ~ R55, ... , R244 ~ R247, R252, R253, R254.
(R254 is the highest-numbered general-purpose register in SASS; R255 is the zero register RZ.) From R36 to R247 there is a pattern: 4 consecutive registers need to be retrieved, the next 4 do not, the next 4 do, and so on.
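The list above can be encoded as a predicate, which is handy when checking a register mapping. This is just a restatement of the observed rule (for the CUDA version used here), not an official ABI description.

```python
def is_special(r):
    """True if R<r> must hold its original value when the device kernel executes RET."""
    if r == 1 or 16 <= r <= 31 or 252 <= r <= 254:
        return True
    if 36 <= r <= 247:
        return (r - 36) % 8 < 4   # 4 preserved, 4 free, repeating
    return False


if __name__ == "__main__":
    special = [r for r in range(255) if is_special(r)]   # R0 ~ R254
    print(len(special), special[:8])
```

By this rule, 128 of the 255 registers R0 ~ R254 must be preserved across a device kernel call.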
Changed order of argument passing
The original SASS code passes arguments in the following order:
C, A, B, alpha, beta, cda, cdb, cdc, m, n, k, blk_a, blk_b
Actually, the order of arguments does not matter much for global kernels, because all arguments are stored in constant memory. Note that our latest code renamed the arguments blk_a and blk_b to idx_A and idx_B, respectively.
If we turn this into a device kernel, the arguments are passed via registers and local memory as follows. Note that C, A, and B are 8 bytes each since they are pointers; the others are 4 bytes each.
R4: Lower 4 bytes of C
R5: Higher 4 bytes of C
R6: Lower 4 bytes of A
R7: Higher 4 bytes of A
R8: Lower 4 bytes of B
R9: Higher 4 bytes of B
R10 ~ R15: alpha, beta, cda, cdb, cdc, m
Local memory[R1]: n
Local memory[R1 + 4]: k
Local memory[R1 + 8]: blk_a
Local memory[R1 + 12]: blk_b
This order is inefficient for device kernels because the n, k, blk_a, and blk_b arguments are used at the beginning of the GEMM algorithm, which makes the kernel wait for local memory reads right at the start. Instead, we change the order so that C, alpha, and beta, which are used only at the final stage of GEMM, are passed via local memory. This lets us issue the local memory reads at the beginning and move on without waiting for the results (the loads set no read/write flags); the results are then already available when we need them at the final stage. The changed order of argument passing is:
blk_a, blk_b, cda, cdb, cdc, m, n, k, A, B, C, alpha, beta
We read the arguments in local memory like this:
--:-:-:-:1 LDL arg_C0, [R1];
--:-:-:-:1 LDL arg_C1, [R1+0x4];
--:-:-:-:1 LDL arg_alpha, [R1+0x8];
--:-:-:-:1 LDL arg_beta, [R1+0xc];
One may ask why I am not using the 16-byte load instruction (LDL.128). We have to read in 4-byte units because the address in R1 may not be 16-byte-aligned, e.g. 0xfffcc8. In this case, the 16-byte load raises a misaligned memory error.
Changed register mappings
Device kernels need to retrieve the values of specific registers (I will call them "special registers" here) listed in Backgrounds above. Since this is extra work, NVCC avoids using special registers in device kernels to reduce the retrieval work. The original OpenAI SASS code ignored this and simply used all registers R0 ~ R159, so I changed its register mapping to avoid special registers as much as possible.
It is impossible to use only non-special registers because the GEMM kernel needs a lot of registers. Still, we at least want to use relatively high-numbered special registers. This is because the global kernel uses low-numbered registers first and higher-numbered ones only if it needs more, so we may not need to retrieve high-numbered special registers if the global kernel does not use many registers. Thus I moved all the original register mappings into the high-numbered area.
Using the high-numbered area has a drawback: it requires more registers. I don't know exactly why (presumably each thread is allocated a contiguous range of registers starting from R0), but, for instance, if a kernel uses only R0, the compiled kernel's resource usage is 1 register per thread, whereas if it uses only R254, the resource usage becomes 255 registers per thread. This means that the highest register number, not the number of registers actually in use, determines the number of registers allocated per thread. Due to this, our current SASS code uses 255 registers per thread (2 blocks per SM), while the original code used 159 registers (3 blocks per SM). However, I still think this is fine for two reasons:
1) When I compared the throughput of 2 blocks per SM and 3 blocks per SM on P100, the difference was not significant, although this result depends on the computing power of the GPU.
2) ARK cannot schedule the computation of 3 blocks (tiles) on a single SM anyway, due to the shared memory limit: ARK uses only one block per SM, and the maximum shared memory per block is only half of the total shared memory of an SM.
As a result, our current SASS code uses special registers only from R132 upward. This means that if the global kernel uses fewer than 132 registers, we do not need to retrieve any register values at all. We also have some spare non-special registers: R32 ~ R35, R40 ~ R43, ..., R80 ~ R83. If the global kernel uses more than 132 registers, we can use these non-special registers for context retrieval by temporarily keeping special register values in them; this works up to R183, since the 28 spare registers can hold the 28 special registers in R132 ~ R183. This context retrieval is fast because it uses only registers, but it still degrades kernel throughput a little. If the global kernel uses more than 183 registers, we would need to store those values in local memory, but I didn't implement this. Instead, we keep the global kernel's register usage low.
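The register-based context retrieval described above can be sketched as a generator of MOV pairs: each special register at or above R132 that the calling global kernel actually uses is copied into one of the spare non-special registers R32 ~ R83 at entry and copied back before RET. The helper name and the `--:-:-:-:1` control codes are placeholders of ours, not code from xgemm_32x32x32.sass.

```python
def is_special(r):
    """Same rule as the register list in Backgrounds."""
    if r == 1 or 16 <= r <= 31 or 252 <= r <= 254:
        return True
    if 36 <= r <= 247:
        return (r - 36) % 8 < 4
    return False


def save_plan(global_regs, first=132, spare_lo=32, spare_hi=83):
    """Return (save, restore) SASS MOV lines for a caller using R0 ~ R<global_regs-1>.

    Only special registers in [first, global_regs) need saving; each one is
    paired with a spare non-special register in [spare_lo, spare_hi].
    """
    to_save = [r for r in range(first, global_regs) if is_special(r)]
    spares = [r for r in range(spare_lo, spare_hi + 1) if not is_special(r)]
    if len(to_save) > len(spares):
        raise ValueError("not enough spare registers; local memory would be needed")
    pairs = list(zip(to_save, spares))
    save = ["--:-:-:-:1 MOV R{}, R{};".format(s, r) for r, s in pairs]      # at entry
    restore = ["--:-:-:-:1 MOV R{}, R{};".format(r, s) for r, s in pairs]   # before RET
    return save, restore


if __name__ == "__main__":
    save, restore = save_plan(global_regs=140)
    print("\n".join(save))
    print("\n".join(restore))
```

With a caller that uses 132 or fewer registers, both lists are empty; with a caller that uses up to R183, all 28 spare registers are consumed.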
Some register mappings were not simply moved. E.g., I changed the mapping so that tid and tid31 share the same register (R96), and moved writeCs to a low-numbered register (R3). These are minor optimizations to use as few special registers as possible.
Removed tile index calculation
The original SASS code receives blk_a and blk_b as arguments and reads the thread and block indexes from the hardware special registers SR_TID and SR_CTAID (not to be confused with the "special registers" defined above) to calculate idx_A and idx_B, which indicate the output tile to compute. I removed this part and pass idx_A and idx_B directly as arguments, so that we can manually assign a specific tile to each block. Also, since the calculation of idx_A and idx_B used many write dependencies, removing it improved kernel throughput.