online softmax - beyondnlp/nlp GitHub Wiki
https://github.com/karpathy/llama2.c/blob/b3c4b6c3c4bbff42e5211293280307019368ccb5/run.c#L197
์ ์ฝ๋๋ karpathy๊ฐ ๋ง๋ llama2.c ์ฝ๋ ์ค softmax๋ฅผ ๊ตฌํ๋ถ์ ๋๋ค.
void softmax(float* x, int size) {
// find max value (for numerical stability)
float max_val = x[0];
for (int i = 1; i < size; i++) {
if (x[i] > max_val) {
max_val = x[i];
}
}
// exp and sum
float sum = 0.0f;
for (int i = 0; i < size; i++) {
x[i] = expf(x[i] - max_val);
sum += x[i];
}
// normalize
for (int i = 0; i < size; i++) {
x[i] /= sum;
}
}
softmax๋ ๋ชจ๋ ์ซ์๋ฅผ ํ๋ฅ ๋ฒกํฐ๋ก ๋ณํํ๊ธฐ ์ํด ์ฌ์ฉ๋ฉ๋๋ค. ์ ๋ ฅ ๋ฒกํฐ๋ฅผ ๋ชจ๋ ์ง์๋ก ๊ณ์ฐ์ ํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ์ ์ฒดํฉ์ผ๋ก ๊ฐ ๋ฒกํฐ๋ฅผ ๋๋๋ ๊ฒ์ผ๋ก ๊ฐ์ ๊ณ์ฐํ ์ ์์ต๋๋ค.
def EXP( b ):
alpha=2.7182818284590452353602874713526624977 # =e^1
return alpha ** b
def softmax_normal(x):
exp_x=[]
for nx in x:
e1=EXP( nx )
exp_x.append( e1 )
sum_exp_x = 1e-10
for i in exp_x: sum_exp_x = sum_exp_x + i
y=[]
for es in exp_x: y.append(es/sum_exp_x)
return y
softmax_normal()ํจ์๋ ๊ฐ์ฅ ๊ธฐ๋ณธ์ด ๋๋ ๊ตฌํ ๋ฐฉ๋ฒ์ ๋๋ค.
์๋๋ ์ง์ํจ์์ ํน์ฑ์ผ๋ก ์ธํด overflow๋ underflow๋ฅผ ๋ง๊ธฐ ์ํด
๋ถ๋ชจ๊ฐ ๋๋ sum_exp_x์ 1e-10์ ๊ฐ์ ์์ ๊ฐ์ ๋ํ๊ณ overflow๋ฅผ ๋ง๊ธฐ ์ํด ๋ฒกํฐ์ค ๊ฐ์ฅํฐ ๊ฐ์ ๋๋๊ฒ ๋ฉ๋๋ค. https://upload.wikimedia.org/wikipedia/commons/thumb/c/c6/Exp.svg/2560px-Exp.svg.png
def softmax_max(x):
max_one = -math.inf
for nx in x:
if nx > max_one : max_one = nx
exp_x=[]
for nx in x:
new_nx = nx-max_one
e1=EXP( new_nx )
exp_x.append( e1 )
sum_exp_x = 1e-10
for i in exp_x: sum_exp_x = sum_exp_x + i
print('[max]\tsum_exp_x', sum_exp_x)
y=[]
for es in exp_x: y.append(es/sum_exp_x)
return y
์์ ์ค๋ช ํ softmax_max์ ๊ฒฝ์ฐ max value๋ฅผ ์ฐพ๊ธฐ ์ํด ๋ฒกํฐ๋ฅผ ํ๋ฒ ์ํํ๊ฒ ๋๋๋ฐ ๊ณ์ฐ ์๋๋ฅผ ๋น ๋ฅด๊ฒ ํ๊ธฐ ์ํด ์๋์ ๊ฐ์ด for loop๋ฅผ ํ๋ ์ ๊ฑฐํ ๋ฒ์ ์ด flash attention๊ณผ ๊ฐ์ ์๋ ์ต์ ํ์ ์ฌ์ฉ์ด ๋ฉ๋๋ค.
def softmax_online(x):
N = len(x)
# Initialize variables
m = [-1] * N
d = [0] * N
y = [0] * N
old_max = -math.inf
new_max = -math.inf
exp_sum_x = 0;
for i in range(N):
new_max = get_max( old_max, x[i] )
a=EXP( old_max - new_max )
b=EXP( x[i] - new_max )
exp_sum_x = exp_sum_x * a + b
old_max = new_max
for i in range(N): y[i] = EXP( x[i] - old_max ) / exp_sum_x;
return y
- old_max - new_max๋ 0์ด์์ ๊ฐ์ ๊ฐ์ง๋ค.( new_max๋ old_max๋ณด๋ค ํญ์ ํฌ๊ธฐ ๋๋ฌธ์ )
- ์ฆ a๋ exp(0์ด์์๊ฐ)=1 ์ด์์ ๊ฐ์ด ๋๋ค.
- a๊ฐ 1์ด๋ฉด ์ด์ ๊ฐ์ b์ ๊ฐ๋ง exp_sum_x์ ๋ํด์ง๋ ์ํฉ
- exp(3) * exp(2) = exp(5)
[3, 5, 1]์ด ์ ๋ ฅ๋ฒกํฐ์ผ๋ ์ต๋๊ฐ์ 5๊ฐ ๋๋ค. ๋ชจ๋ ๋ฒกํฐ๋ ์ต๋๊ฐ์ ๋บ ํํ๋ก ์ง์ํจ์๋ฅผ ๊ฑฐ์น๊ฒ ๋๋ค. [-2, 0, -4]์ ํํ๋ก ๊ฐ ๋ฒกํฐ๊ฐ ์ง์ ํจ์๋ฅผ ๊ฑฐ์น๊ฐ์ ๋ชจ๋ ๋ํด ๋ถ๋ชจ๋ฅผ ์์ฑํ๋ค.
softmax_max์์ ์ฒซ๋ฒ์งธ ์์ดํ -2๋ฅผ ๊ธฐ์ค์ผ๋ก ์๊ฐํ๋ฉด exp(-2)=0.1353352832366127์ด ๋๋ค.
๊ทธ๋ฌ๋ฉด online_softmax()๋ฅผ ๊ธฐ์ค์ผ๋ก ๋ณด๋ฉด [3, 5, 1]์ '3'์ ๊ณ์ฐํ ๋๋ max๊ฐ์ด 3์ด๋ผ์ exp(0)์ด ๋๋ค. softmax_max์์ ์ฐจ์ด๋ -2๊ฐ ์๊ธฐ๋๋ฐ '5'๋ฅผ ๊ณ์ฐ์์ ์ sum=sum*a + b๋ก ๊ณ์ฐ, a=old_max - new_max๋ก ๊ณ์ฐ ์์์ ์ธ๊ธํ ์ฐจ์ด(-2) ์ฌ๊ธฐ์ ๊ณ์ฐ์ด ๋๋ค๊ณ ๋ณด๋ฉด ๋๋ค.
- a=old_max-new_max๋ ํญ์ 0์ด์์ด ๋๋ค
- ์ฆ exp(a)๋ ํญ์ 1 ์ด์์ ๊ฐ๊ฒ๋๋ค
- ๋ง์ฝ old_max์ new_max๊ฐ ๊ฐ์ผ๋ฉด(์ ๋ฐ์ดํธ๊ฐ ์์ผ๋ฉด) exp(a)๋ 1์ด ๋๋ค
- ํ์ฌ๊ฐ๋ง new_max๋ก ๊ณ์ฐํ๋ ค ๋ํ๋ค.
- ์ ๋ฐ์ดํธ๊ฐ ๋๋ฉด, new_max๊ฐ ์ปค์ง๋ฉด old_max์์ ์ฐจ์ด๋งํผ sum์ ๊ณฑํด์ง ๊ฐ์ด sum์ ๋ค์ ์ถ๊ฐ๋๋ ๋ก์ง์ด๋ค.