bfloat16 Arm

New instructions in Armv8.6-A

  • Four new instruction families: BFCVT, BFMLAL, BFDOT and BFMMLA
  • Available in SVE and Neon (AArch64 and AArch32)
  • The multiply-accumulate instructions accept BF16 inputs but do not generate BF16 results
  • Instead they accumulate into an FP32 intermediate result, for better accuracy (the BF16/FP32 format relationship is sketched below)
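
BF16 is the upper half of an IEEE-754 FP32 value: 1 sign bit, 8 exponent bits and 7 mantissa bits. A minimal C sketch of that relationship (plain truncation only; the real BFCVT instruction also applies rounding):

#include <stdint.h>
#include <string.h>

/* Narrow FP32 to BF16 by keeping the upper 16 bits.
   Illustrative only: BFCVT additionally rounds. */
static uint16_t f32_to_bf16_truncate(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return (uint16_t)(bits >> 16);
}

/* Widen BF16 back to FP32; this direction is exact. */
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}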

Instruction table

Conversion instructions

| Instruction | Description | NEON syntax |
| --- | --- | --- |
| BFCVT | FP32 to BF16 (scalar) | `BFCVT <Hd>, <Sn>` |
| BFCVTN | FP32 to BF16 (vector); write to lower half and clear upper half | `BFCVTN <Vd>.4H, <Vn>.4S` |
| BFCVTN2 | FP32 to BF16 (vector); write to upper half, preserve lower half | `BFCVTN2 <Vd>.8H, <Vn>.4S` |

Multiply and accumulate instructions

| Instruction | Description | BF16 mult ops | NEON syntax |
| --- | --- | --- | --- |
| BFMLALB | Even-element MAC | 4 | `BFMLALB <Vd>.4S, <Vn>.8H, <Vm>.8H` |
| BFMLALT | Odd-element MAC | 4 | `BFMLALT <Vd>.4S, <Vn>.8H, <Vm>.8H` |
| BFDOT | [1x2] x [2x1] MAC per FP32 lane | 4 or 8 | `BFDOT <Vd>.2S, <Vn>.4H, <Vm>.4H` / `BFDOT <Vd>.4S, <Vn>.8H, <Vm>.8H` |
| BFMMLA | [2x4] x [4x2] MAC | 16 | `BFMMLA <Vd>.4S, <Vn>.8H, <Vm>.8H` |

Illustration

| <16-bit> | H7 | H6 | H5 | H4 | H3 | H2 | H1 | H0 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Input 1 (Vn) | a7 | a6 | a5 | a4 | a3 | a2 | a1 | a0 |
| Input 2 (Vm) | b7 | b6 | b5 | b4 | b3 | b2 | b1 | b0 |

| <32-bit> | S3 += | S2 += | S1 += | S0 += |
| --- | --- | --- | --- | --- |
| BFMLALB | a6b6 | a4b4 | a2b2 | a0b0 |
| BFMLALT | a7b7 | a5b5 | a3b3 | a1b1 |
| BFDOT | a7b7 + a6b6 | a5b5 + a4b4 | a3b3 + a2b2 | a1b1 + a0b0 |
| BFMMLA | a7b7 + a6b6 + a5b5 + a4b4 | a7b3 + a6b2 + a5b1 + a4b0 | a3b7 + a2b6 + a1b5 + a0b4 | a3b3 + a2b2 + a1b1 + a0b0 |
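
A scalar C model of the lane pairings in the table above can make the element grouping explicit. Here a[8] and b[8] hold raw BF16 bit patterns and s[4] the FP32 accumulators; this is only an approximation of the architected numerics (rounding and denormal handling differ):

#include <stdint.h>
#include <string.h>

static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* BFMLALB: even elements only (BFMLALT would use index 2*i + 1). */
static void bfmlalb_model(float s[4], const uint16_t a[8], const uint16_t b[8]) {
    for (int i = 0; i < 4; i++)
        s[i] += bf16_to_f32(a[2*i]) * bf16_to_f32(b[2*i]);
}

/* BFDOT: each FP32 lane accumulates a 2-element dot product. */
static void bfdot_model(float s[4], const uint16_t a[8], const uint16_t b[8]) {
    for (int i = 0; i < 4; i++)
        s[i] += bf16_to_f32(a[2*i])     * bf16_to_f32(b[2*i])
              + bf16_to_f32(a[2*i + 1]) * bf16_to_f32(b[2*i + 1]);
}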

Instruction details

BFDOT

  • A [1x2] x [2x1] dot product of BF16 elements
  • Accumulating into each FP32 element of the SIMD result

| Characteristic | Description |
| --- | --- |
| Operation | Accumulator += Inp1 x Inp2 |
| Precision | FP32 += BF16 x BF16 |
| Matrix size (per lane) | [1x1] += [1x2] x [2x1] |
| Elements (per lane) | 1 FP32 element += 2 BF16 elements x 2 BF16 elements |
| Register size | 64-bit += 64-bit x 64-bit, or 128-bit += 128-bit x 128-bit |
| BF16 mult ops | 4 (2S form) or 8 (4S form) |

Neon

BFDOT <Vd>.2S, <Vn>.4H, <Vm>.4H
// Each FP32 lane in Vd += [1x2] x [2x1] BF16 dot product ; 4 BF16 mults

Vd.S[0] += (Vn.H[0] x Vm.H[0]) + (Vn.H[1] x Vm.H[1])
Vd.S[1] += (Vn.H[2] x Vm.H[2]) + (Vn.H[3] x Vm.H[3])

BFDOT <Vd>.4S, <Vn>.8H, <Vm>.8H
// Each FP32 lane in Vd += [1x2] x [2x1] BF16 dot product ; 8 BF16 mults

Vd.S[0] += (Vn.H[0] x Vm.H[0]) + (Vn.H[1] x Vm.H[1])
Vd.S[1] += (Vn.H[2] x Vm.H[2]) + (Vn.H[3] x Vm.H[3])
Vd.S[2] += (Vn.H[4] x Vm.H[4]) + (Vn.H[5] x Vm.H[5])
Vd.S[3] += (Vn.H[6] x Vm.H[6]) + (Vn.H[7] x Vm.H[7])
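
With a toolchain that has ACLE BF16 support (e.g. GCC or Clang with -march=armv8.6-a+bf16), the 4S form maps to the vbfdotq_f32 intrinsic; a minimal sketch:

#include <arm_neon.h>   // compile with -march=armv8.6-a+bf16

/* acc: 4 FP32 accumulators; a, b: 8 BF16 elements each.
   Each FP32 lane gains a 2-element BF16 dot product (one BFDOT). */
float32x4_t bfdot_step(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
    return vbfdotq_f32(acc, a, b);
}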

BFMMLA

  • BFloat16 floating-point matrix multiply-accumulate into a 2x2 matrix
  • result[2, 2] = addend[2, 2] + (op1[2, 4] x op2[4, 2])

| Characteristic | Description |
| --- | --- |
| Operation | Accumulator += Inp1 x Inp2 |
| Precision | FP32 += BF16 x BF16 |
| Matrix size | [2x2] += [2x4] x [4x2] |
| Register size | 128-bit += 128-bit x 128-bit |
| BF16 mult throughput | >= two BFDOT instructions |
| Ops | 16 BF16 mults |

Neon

BFMMLA <Vd>.4S, <Vn>.8H, <Vm>.8H
// [2x2] FP32 tile in Vd += [2x4] in Vn x [4x2] in Vm ; 16 BF16 mults

Vd.S[0] += (Vn.H[0] x Vm.H[0]) + (Vn.H[1] x Vm.H[1]) + (Vn.H[2] x Vm.H[2]) + (Vn.H[3] x Vm.H[3])
Vd.S[1] += (Vn.H[0] x Vm.H[4]) + (Vn.H[1] x Vm.H[5]) + (Vn.H[2] x Vm.H[6]) + (Vn.H[3] x Vm.H[7]) 
Vd.S[2] += (Vn.H[4] x Vm.H[0]) + (Vn.H[5] x Vm.H[1]) + (Vn.H[6] x Vm.H[2]) + (Vn.H[7] x Vm.H[3]) 
Vd.S[3] += (Vn.H[4] x Vm.H[4]) + (Vn.H[5] x Vm.H[5]) + (Vn.H[6] x Vm.H[6]) + (Vn.H[7] x Vm.H[7]) 
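
The matching ACLE intrinsic is vbfmmlaq_f32 (same +bf16 toolchain requirement). Note the operand layout implied by the pseudocode above: Vn holds a row-major 2x4 tile of op1, while Vm holds op2 column-major, i.e. a row-major 2x4 tile of op2 transposed:

#include <arm_neon.h>   // compile with -march=armv8.6-a+bf16

/* acc lanes {S0, S1, S2, S3} hold the 2x2 FP32 tile {r0c0, r0c1, r1c0, r1c1}.
   a: rows 0..1 of op1 (4 BF16 each); b: columns 0..1 of op2 (4 BF16 each). */
float32x4_t bfmmla_step(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
    return vbfmmlaq_f32(acc, a, b);
}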

BFMLAL

  • An element-wise product of the even (BFMLALB) or odd (BFMLALT) BF16 elements
  • Accumulating into each FP32 element of the SIMD result

Neon

BFMLALB <Vd>.4S, <Vn>.8H, <Vm>.8H // Vd.S[i] += FP32(Vn.H[2i]) x FP32(Vm.H[2i]) ; 4 BF16 mults
BFMLALT <Vd>.4S, <Vn>.8H, <Vm>.8H // Vd.S[i] += FP32(Vn.H[2i+1]) x FP32(Vm.H[2i+1]) ; 4 BF16 mults
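
The corresponding ACLE intrinsics are vbfmlalbq_f32 and vbfmlaltq_f32 (again assuming a +bf16 toolchain):

#include <arm_neon.h>   // compile with -march=armv8.6-a+bf16

/* Widening multiply-accumulate of the even (B) or odd (T) BF16 elements;
   each call performs 4 BF16 multiplies into the 4 FP32 lanes of acc. */
float32x4_t mlal_even(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
    return vbfmlalbq_f32(acc, a, b);
}

float32x4_t mlal_odd(float32x4_t acc, bfloat16x8_t a, bfloat16x8_t b) {
    return vbfmlaltq_f32(acc, a, b);
}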

BFCVT

  • BFCVT - FP32 to BF16 (scalar)
  • BFCVTN - FP32 to BF16 (vector)
    • Write to the lower half of the destination vector
    • Clear the upper half to zero
  • BFCVTN2 - FP32 to BF16 (vector)
    • Write to the upper half of the destination vector
    • Do not modify the lower half

Neon

BFCVT    <Hd>, <Sn>          // Vd.H[0] = BF16(Vn.S[0]) ; scalar convert

BFCVTN   <Vd>.4H, <Vn>.4S    // BF16 <= FP32 ; write lower half, clear upper ; 4 converts
BFCVTN2  <Vd>.8H, <Vn>.4S    // BF16 <= FP32 ; write upper half, keep lower ; 4 converts
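
These conversions are exposed in ACLE as vcvth_bf16_f32, vcvtq_low_bf16_f32 and vcvtq_high_bf16_f32 (assuming a +bf16 toolchain). A sketch narrowing eight FP32 lanes into one BF16 vector:

#include <arm_neon.h>   // compile with -march=armv8.6-a+bf16

/* Scalar convert: one FP32 to one BF16 (BFCVT). */
bfloat16_t cvt_one(float s) {
    return vcvth_bf16_f32(s);
}

/* BFCVTN fills the lower half (clearing the upper),
   BFCVTN2 then fills the upper half (keeping the lower). */
bfloat16x8_t cvt_eight(float32x4_t lo, float32x4_t hi) {
    bfloat16x8_t v = vcvtq_low_bf16_f32(lo);   // BFCVTN
    return vcvtq_high_bf16_f32(v, hi);         // BFCVTN2
}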

Reference - NEON register map
