online softmax

https://github.com/karpathy/llama2.c/blob/b3c4b6c3c4bbff42e5211293280307019368ccb5/run.c#L197

์œ„ ์ฝ”๋“œ๋Š” karpathy๊ฐ€ ๋งŒ๋“  llama2.c ์ฝ”๋“œ ์ค‘ softmax๋ฅผ ๊ตฌํ˜„๋ถ€์ž…๋‹ˆ๋‹ค.

#include <math.h>  // for expf

void softmax(float* x, int size) {
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}

softmax๋Š” ๋ชจ๋“  ์ˆซ์ž๋ฅผ ํ™•๋ฅ  ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. ์ž…๋ ฅ ๋ฒกํ„ฐ๋ฅผ ๋ชจ๋‘ ์ง€์ˆ˜๋กœ ๊ณ„์‚ฐ์„ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ ์ „์ฒดํ•ฉ์œผ๋กœ ๊ฐ ๋ฒกํ„ฐ๋ฅผ ๋‚˜๋ˆ„๋Š” ๊ฒƒ์œผ๋กœ ๊ฐ’์„ ๊ณ„์‚ฐํ• ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import math

def EXP(b):
    # e**b from a hard-coded value of e; equivalent to math.exp(b)
    alpha = 2.7182818284590452353602874713526624977  # e
    return alpha ** b

def softmax_normal(x):
    # exponentiate every element
    exp_x = []
    for nx in x:
        exp_x.append(EXP(nx))
    # sum the exponentials; 1e-10 guards against division by zero
    sum_exp_x = 1e-10
    for i in exp_x:
        sum_exp_x = sum_exp_x + i
    # normalize by the sum
    y = []
    for es in exp_x:
        y.append(es / sum_exp_x)
    return y

softmax_normal() is the most basic implementation.
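A quick usage check (the same [3, 5, 1] input is used as the worked example later on this page):

print(softmax_normal([3, 5, 1]))
# ≈ [0.1173, 0.8668, 0.0159]; the outputs are positive and sum to 1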

์•„๋ž˜๋Š” ์ง€์ˆ˜ํ•จ์ˆ˜์˜ ํŠน์„ฑ์œผ๋กœ ์ธํ•ด overflow๋‚˜ underflow๋ฅผ ๋ง‰๊ธฐ ์œ„ํ•ด

๋ถ„๋ชจ๊ฐ€ ๋˜๋Š” sum_exp_x์— 1e-10์™€ ๊ฐ™์€ ์ž‘์€ ๊ฐ’์„ ๋”ํ•˜๊ณ  overflow๋ฅผ ๋ง‰๊ธฐ ์œ„ํ•ด ๋ฒกํ„ฐ์ค‘ ๊ฐ€์žฅํฐ ๊ฐ’์„ ๋‚˜๋ˆ„๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. https://upload.wikimedia.org/wikipedia/commons/thumb/c/c6/Exp.svg/2560px-Exp.svg.png

def softmax_max(x):
    # first pass: find the maximum (for numerical stability)
    max_one = -math.inf
    for nx in x:
        if nx > max_one:
            max_one = nx

    # second pass: exponentiate with the max subtracted,
    # so every argument to EXP is <= 0
    exp_x = []
    for nx in x:
        exp_x.append(EXP(nx - max_one))

    # sum the exponentials, then normalize
    sum_exp_x = 1e-10
    for i in exp_x:
        sum_exp_x = sum_exp_x + i
    print('[max]\tsum_exp_x', sum_exp_x)
    y = []
    for es in exp_x:
        y.append(es / sum_exp_x)
    return y
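A minimal sketch of why this matters (assuming softmax_normal and softmax_max above are in scope): on large inputs the naive version overflows a Python float, while the max-subtracted version stays stable.

big = [1000.0, 1001.0, 1002.0]

try:
    softmax_normal(big)   # EXP(1000) = e**1000 overflows a Python float
except OverflowError as err:
    print('softmax_normal overflowed:', err)

print(softmax_max(big))   # ≈ [0.0900, 0.2447, 0.6652], no overflow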

์œ„์— ์„ค๋ช…ํ•œ softmax_max์˜ ๊ฒฝ์šฐ max value๋ฅผ ์ฐพ๊ธฐ ์œ„ํ•ด ๋ฒกํ„ฐ๋ฅผ ํ•œ๋ฒˆ ์ˆœํšŒํ•˜๊ฒŒ ๋˜๋Š”๋ฐ ๊ณ„์‚ฐ ์†๋„๋ฅผ ๋น ๋ฅด๊ฒŒ ํ•˜๊ธฐ ์œ„ํ•ด ์•„๋ž˜์™€ ๊ฐ™์ด for loop๋ฅผ ํ•˜๋‚˜ ์ œ๊ฑฐํ•œ ๋ฒ„์ „์ด flash attention๊ณผ ๊ฐ™์€ ์†๋„ ์ตœ์ ํ™”์— ์‚ฌ์šฉ์ด ๋ฉ๋‹ˆ๋‹ค.

def softmax_online(x):
    N = len(x)
    y = [0.0] * N

    # single fused pass: track the running max and keep the running
    # sum expressed relative to that max
    old_max = -math.inf
    exp_sum_x = 0.0
    for i in range(N):
        new_max = max(old_max, x[i])
        a = EXP(old_max - new_max)   # rescales the sum accumulated so far
        b = EXP(x[i] - new_max)      # contribution of the current element
        exp_sum_x = exp_sum_x * a + b
        old_max = new_max

    # final pass: normalize against the final max and sum
    for i in range(N):
        y[i] = EXP(x[i] - old_max) / exp_sum_x

    return y

  • old_max - new_max is always ≤ 0, because new_max is always ≥ old_max.
  • So a = exp(old_max - new_max) always lies in the range (0, 1].
  • When a is 1 (the max did not change), the previous sum is kept as-is and only b is added to exp_sum_x.
  • The rescaling relies on exp(u) * exp(v) = exp(u + v), e.g. exp(3) * exp(2) = exp(5): each stored term exp(x_j - old_max), multiplied by a = exp(old_max - new_max), becomes exp(x_j - new_max).
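In symbols, the fused loop maintains the recurrence (with $m_0 = -\infty$, $d_0 = 0$; $m_i$ is the running max and $d_i$ the running denominator):

$$m_i = \max(m_{i-1},\, x_i), \qquad d_i = d_{i-1}\,e^{\,m_{i-1} - m_i} + e^{\,x_i - m_i}$$

After the last element, $d_N = \sum_j e^{x_j - m_N}$, exactly the denominator that softmax_max computes with its separate max pass.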

[3, 5, 1]์ด ์ž…๋ ฅ๋ฒกํ„ฐ์ผ๋•Œ ์ตœ๋Œ€๊ฐ’์€ 5๊ฐ€ ๋œ๋‹ค. ๋ชจ๋“  ๋ฒกํ„ฐ๋Š” ์ตœ๋Œ€๊ฐ’์„ ๋บ€ ํ˜•ํƒœ๋กœ ์ง€์ˆ˜ํ•จ์ˆ˜๋ฅผ ๊ฑฐ์น˜๊ฒŒ ๋œ๋‹ค. [-2, 0, -4]์˜ ํ˜•ํƒœ๋กœ ๊ฐ ๋ฒกํ„ฐ๊ฐ€ ์ง€์ˆ˜ ํ•จ์ˆ˜๋ฅผ ๊ฑฐ์นœ๊ฐ’์„ ๋ชจ๋‘ ๋”ํ•ด ๋ถ„๋ชจ๋ฅผ ์™„์„ฑํ•œ๋‹ค.

softmax_max์—์„œ ์ฒซ๋ฒˆ์งธ ์•„์ดํ…œ -2๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์ƒ๊ฐํ•˜๋ฉด exp(-2)=0.1353352832366127์ด ๋œ๋‹ค.

๊ทธ๋Ÿฌ๋ฉด online_softmax()๋ฅผ ๊ธฐ์ค€์œผ๋กœ ๋ณด๋ฉด [3, 5, 1]์— '3'์„ ๊ณ„์‚ฐํ• ๋•Œ๋Š” max๊ฐ’์ด 3์ด๋ผ์„œ exp(0)์ด ๋œ๋‹ค. softmax_max์™€์˜ ์ฐจ์ด๋Š” -2๊ฐ€ ์ƒ๊ธฐ๋Š”๋ฐ '5'๋ฅผ ๊ณ„์‚ฐ์‹œ์ ์— sum=sum*a + b๋กœ ๊ณ„์‚ฐ, a=old_max - new_max๋กœ ๊ณ„์‚ฐ ์œ„์—์„œ ์–ธ๊ธ‰ํ•œ ์ฐจ์ด(-2) ์—ฌ๊ธฐ์„œ ๊ณ„์‚ฐ์ด ๋œ๋‹ค๊ณ  ๋ณด๋ฉด ๋œ๋‹ค.

  • old_max - new_max is always ≤ 0, so a = exp(old_max - new_max) always lies in (0, 1].
  • If old_max and new_max are equal (no update), a = exp(0) = 1.
  • In that case only the current element's exp(x[i] - new_max) is added to the sum.
  • If the max is updated (new_max grows), the running sum is first multiplied by exp(old_max - new_max), down-scaling every previously accumulated term to the new reference point, and then the new term is added.