bfloat16 Arm - AshokBhat/ml GitHub Wiki
- Four new instructions
- Available in SVE and Neon (AArch64 and AArch32)
- Accept BF16 inputs; do not generate BF16 results
- Accumulate into an FP32 intermediate result, for better accuracy
Instruction | Description | NEON syntax |
---|---|---|
BFCVT | fp32 to bfloat16 (scalar) | `BFCVT <Hd>, <Sn>` |
BFCVTN | fp32 to bfloat16 (vector); write to lower half, clear upper half | `BFCVTN <Vd>.4H, <Vn>.4S` |
BFCVTN2 | fp32 to bfloat16 (vector); write to upper half | `BFCVTN2 <Vd>.8H, <Vn>.4S` |
Instruction | Description | BF16 mult ops | NEON syntax |
---|---|---|---|
BFMLALB | Even-element MAC | 4 | `BFMLALB <Vd>.4S, <Vn>.8H, <Vm>.8H` |
BFMLALT | Odd-element MAC | 4 | `BFMLALT <Vd>.4S, <Vn>.8H, <Vm>.8H` |
BFDOT | [1×2] × [2×1] MAC | 4 or 8 | `BFDOT <Vd>.2S, <Vn>.4H, <Vm>.4H` <br> `BFDOT <Vd>.4S, <Vn>.8H, <Vm>.8H` |
BFMMLA | [2×4] × [4×2] MAC | 16 | `BFMMLA <Vd>.4S, <Vn>.8H, <Vm>.8H` |
<16-bit> | H7 | H6 | H5 | H4 | H3 | H2 | H1 | H0 |
---|---|---|---|---|---|---|---|---|
Input 1 | a7 | a6 | a5 | a4 | a3 | a2 | a1 | a0 |
Input 2 | b7 | b6 | b5 | b4 | b3 | b2 | b1 | b0 |

<32-bit> | S3+= | S2+= | S1+= | S0+= |
---|---|---|---|---|
BFMLALB | a6b6 | a4b4 | a2b2 | a0b0 |
BFMLALT | a7b7 | a5b5 | a3b3 | a1b1 |
BFDOT | a7b7 + a6b6 | a5b5 + a4b4 | a3b3 + a2b2 | a1b1 + a0b0 |
BFMMLA | a7b7 + a6b6 + a5b5 + a4b4 | a7b3 + a6b2 + a5b1 + a4b0 | a3b7 + a2b6 + a1b5 + a0b4 | a3b3 + a2b2 + a1b1 + a0b0 |
- A [1×2] × [2×1] dot product of BF16 elements
- Accumulating into each FP32 element within a SIMD result
Characteristic | Description |
---|---|
Operation | Accumulator += Inp1 * Inp2 |
Precision | FP32 += BF16 * BF16 |
Matrix size (per FP32 lane) | 1x1 += 1x2 * 2x1 |
Register size | 64-bit += 64-bit * 64-bit, or 128-bit += 128-bit * 128-bit |
BF16 mult ops | 4 (2S form) or 8 (4S form) |
```
BFDOT <Vd>.2S, <Vn>.4H, <Vm>.4H
// Each FP32 lane of Vd += a [1x2] x [2x1] BF16 dot product ; 2 dot products, 4 BF16 mults
Vd.S[0] += (Vn.H[0] x Vm.H[0]) + (Vn.H[1] x Vm.H[1])
Vd.S[1] += (Vn.H[2] x Vm.H[2]) + (Vn.H[3] x Vm.H[3])
```
```
BFDOT <Vd>.4S, <Vn>.8H, <Vm>.8H
// Each FP32 lane of Vd += a [1x2] x [2x1] BF16 dot product ; 4 dot products, 8 BF16 mults
Vd.S[0] += (Vn.H[0] x Vm.H[0]) + (Vn.H[1] x Vm.H[1])
Vd.S[1] += (Vn.H[2] x Vm.H[2]) + (Vn.H[3] x Vm.H[3])
Vd.S[2] += (Vn.H[4] x Vm.H[4]) + (Vn.H[5] x Vm.H[5])
Vd.S[3] += (Vn.H[6] x Vm.H[6]) + (Vn.H[7] x Vm.H[7])
```
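The lane arithmetic above can be sketched in Python. This is an illustrative model only: ordinary Python floats stand in for the BF16 inputs and FP32 accumulators, and the function name `bfdot_4s` is made up for the sketch, not an actual intrinsic.

```python
def bfdot_4s(vd, vn, vm):
    # Model of the 128-bit BFDOT form:
    # Vd.S[i] += Vn.H[2i]*Vm.H[2i] + Vn.H[2i+1]*Vm.H[2i+1]
    return [vd[i] + vn[2 * i] * vm[2 * i] + vn[2 * i + 1] * vm[2 * i + 1]
            for i in range(4)]

vd = [10.0, 20.0, 30.0, 40.0]                    # FP32 accumulators
vn = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]   # eight BF16 lanes
vm = [1.0] * 8

# Each FP32 lane gains the sum of a neighbouring pair of products.
print(bfdot_4s(vd, vn, vm))  # [13.0, 27.0, 41.0, 55.0]
```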
- BFloat16 floating-point matrix multiply-accumulate into a 2x2 matrix
- result[2, 2] = addend[2, 2] + (op1[2, 4] * op2[4, 2])
Characteristic | Description |
---|---|
Operation | Accumulator += Inp1 * Inp2 |
Precision | FP32 += BF16 * BF16 |
Matrix size | 2x2 += 2x4 * 4x2 |
Register size | 128-bit += 128-bit * 128-bit |
BF16 mult throughput | >= two BFDOT instructions |
Ops | 16 BF16 mults |
```
BFMMLA <Vd>.4S, <Vn>.8H, <Vm>.8H
// 2x2 FP32 result += [2x4] x [4x2] BF16 matrix product ; 16 BF16 mults
Vd.S[0] += (Vn.H[0] x Vm.H[0]) + (Vn.H[1] x Vm.H[1]) + (Vn.H[2] x Vm.H[2]) + (Vn.H[3] x Vm.H[3])
Vd.S[1] += (Vn.H[0] x Vm.H[4]) + (Vn.H[1] x Vm.H[5]) + (Vn.H[2] x Vm.H[6]) + (Vn.H[3] x Vm.H[7])
Vd.S[2] += (Vn.H[4] x Vm.H[0]) + (Vn.H[5] x Vm.H[1]) + (Vn.H[6] x Vm.H[2]) + (Vn.H[7] x Vm.H[3])
Vd.S[3] += (Vn.H[4] x Vm.H[4]) + (Vn.H[5] x Vm.H[5]) + (Vn.H[6] x Vm.H[6]) + (Vn.H[7] x Vm.H[7])
```
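The four lane formulas above are exactly a row-major 2x2 matrix multiply-accumulate. A minimal Python sketch, with ordinary floats standing in for BF16/FP32 and an invented `bfmmla` name: Vn holds a row-major 2x4 matrix, and Vm holds the 4x2 operand column by column (each group of four H lanes is one column).

```python
def bfmmla(vd, vn, vm):
    # Vd.S[2r + c] += dot(row r of Vn, column c stored in Vm[4c:4c+4])
    out = list(vd)
    for r in range(2):      # result row
        for c in range(2):  # result column
            out[2 * r + c] += sum(vn[4 * r + k] * vm[4 * c + k] for k in range(4))
    return out

vn = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]   # 2x4: rows [1..4] and [5..8]
vm = [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]   # 4x2 with columns e0, e1

# Multiplying by those unit columns picks out the first two
# elements of each row: result = [[1, 2], [5, 6]] row-major.
print(bfmmla([0.0] * 4, vn, vm))  # [1.0, 2.0, 5.0, 6.0]
```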
- A simple product of the even or odd BF16 elements
- Accumulating into each FP32 element within a SIMD result
```
BFMLALB <Vd>.4S, <Vn>.8H, <Vm>.8H // Vd += FP32(Vn.even) x FP32(Vm.even) ; 4 BF16 mults
BFMLALT <Vd>.4S, <Vn>.8H, <Vm>.8H // Vd += FP32(Vn.odd) x FP32(Vm.odd) ; 4 BF16 mults
```
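The even/odd split can be sketched in Python. As before, plain floats stand in for BF16 inputs and FP32 accumulators, and the function names are illustrative. Note from the lane tables that a BFMLALB/BFMLALT pair covers the same eight products as one 128-bit BFDOT.

```python
def bfmlalb(vd, vn, vm):
    # Even elements, widened: Vd.S[i] += Vn.H[2i] * Vm.H[2i]
    return [vd[i] + vn[2 * i] * vm[2 * i] for i in range(4)]

def bfmlalt(vd, vn, vm):
    # Odd elements, widened: Vd.S[i] += Vn.H[2i+1] * Vm.H[2i+1]
    return [vd[i] + vn[2 * i + 1] * vm[2 * i + 1] for i in range(4)]

vn = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
vm = [0.5] * 8
zero = [0.0] * 4

even = bfmlalb(zero, vn, vm)  # products of lanes 0, 2, 4, 6
odd = bfmlalt(zero, vn, vm)   # products of lanes 1, 3, 5, 7

# Per lane, even + odd equals the corresponding BFDOT lane result.
combined = [e + o for e, o in zip(even, odd)]
```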
- BFCVT - fp32 to bfloat16 (scalar)
- BFCVTN - fp32 to bfloat16 (vector)
  - Write to lower half of the destination vector
  - Clear the upper half to zero
- BFCVTN2 - fp32 to bfloat16 (vector)
  - Write to upper half of the destination vector
  - Do not modify the lower half
```
BFCVT <Hd>, <Sn>          // Hd = BF16(Sn) ; scalar
BFCVTN <Vd>.4H, <Vn>.4S   // lower half of Vd = BF16(Vn), upper half cleared ; 4 ops
BFCVTN2 <Vd>.8H, <Vn>.4S  // upper half of Vd = BF16(Vn), lower half unchanged ; 4 ops
```
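A BF16 value is simply the top 16 bits of the FP32 encoding, so the narrowing conversions and the half-register behavior can be sketched in Python. This model truncates for simplicity, whereas the hardware instructions round; the function names and raw-bits representation are illustrative only.

```python
import struct

def fp32_to_bf16_bits(x):
    # Top 16 bits of the IEEE-754 single-precision encoding (truncation;
    # the real instructions round the discarded bits).
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16

def bfcvtn(vn):
    # BFCVTN: lower half = converted elements, upper half cleared to zero.
    return [fp32_to_bf16_bits(x) for x in vn] + [0] * 4

def bfcvtn2(vd, vn):
    # BFCVTN2: upper half = converted elements, lower half left unmodified.
    return vd[:4] + [fp32_to_bf16_bits(x) for x in vn]

# Typical pairing: BFCVTN then BFCVTN2 packs eight FP32 values
# (two source vectors) into one vector of eight BF16 lanes.
lo = bfcvtn([1.0, 2.0, 3.0, 4.0])         # upper four lanes are zero
full = bfcvtn2(lo, [5.0, 6.0, 7.0, 8.0])  # lower four lanes preserved
```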