Transformer - BD-SEARCH/MLtutorial GitHub Wiki
- 'Attention is all you need'์์ ๋์จ ๋ชจ๋ธ
- ๊ธฐ์กด์ seq2seq ๊ตฌ์กฐ์ธ ์ธ์ฝ๋-๋์ฝ๋๋ฅผ ๋ฐ๋ฅด๋ฉด์๋ attention๋ง์ผ๋ก ๊ตฌํํ ๋ชจ๋ธ
- ๊ธฐ์กด์ RNN, CNN ๊ตฌ์กฐ์ ํ๊ณ๋ฅผ ํํผํ๊ธฐ ์ํจ
- CNN
- m*m์ ํํฐ๋ฅผ ์ฌ๋ผ์ด๋ฉ
- local์ ์ธ ๊ฒ์ ์ ๋ฐ์ํ์ง๋ง ๋งจ ๋์ชฝ์ ๋ํ ์๋ฏธ๋ฅผ ํ์ ํ๊ธฐ๊ฐ ์ด๋ ค์
- ht = f(xt, xt-k) + ... + f(xt, xt) + ... + f(xt, xt+k)
- RNN
- ์ฐ์์ ์ธ ์ฐ์ฐ
- ht = f(xt, f(xt-1, ... f(x3, f(x2, f(x1)))))
- ์ด์ ์ ๋ฌธ์ฅ์ ํน์ง์ ๋ฐ์ํ ์ ์์
- ๋ฌธ์ฅ์ด ๊ธธ์ด์ง ์๋ก ๋์๊ฐ ๋ฐ์๋๊ธฐ ์ด๋ ค์
- momentum์ด ์๊ฒจ์ dependency๋๋ฌธ์ ๋์ ๋ฐ์ํ ์ ์์
- RN(Relation Network)
- ht = f(xt, x1) + ... + f(xt, xt-1) + f(xt, xt+1) + ... + f(xt, xT)
- xt๋ฅผ ์ ์ธํ๊ณ ๋๋จธ์ง๋ฅผ ์ ์ฒด ์ฐ์ฐ
- ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ค๊ณ ์ ๋์ ๋ค ์ ๋ฐ์ํ ์ ์์ง๋ง ์ฐ์ฐ์ ๊ต์ฅํ ๋ง์ด ํด์ผ ํจ
- CNN
- ์ธํ(๋ฒ์ญํ ์ธ์ด) / ์์ํ(๋ฒ์ญ๋ ์ธ์ด) ๋ผ๋ฒจ์ ์ฃผ๋ฉด ๊ทธ๊ฒ์ ๋ฐํ์ผ๋ก output probabilities๋ฅผ ๋ด๋๋๋ค
- ์ธ์ฝ๋๋ฅผ ๊ฑฐ์ณ์ ํน์ ๊ฐ์ด ๋์ค๊ฒ ๋๋ฉด ๋์ฝ๋์์ right shifted๋ ๋ผ๋ฒจ์ ํต๊ณผํด ํจ๊ผ๋๋ค
- right shifted : soyoung is working at holiday (๊ฐ ๋จ์ด๋ int)
- (soyoung is working at holiday) ํ ํฐ๋ค
- ์๋ฒ ๋ฉ ๋ฒกํฐ๊ฐ์ ์ทจํด์ ๋ฃ์ด์ค๋ค
- ํ ํฐ๋ณ๋ก ๋์ด์ ์ ์ฒด ์ฌ์ด์ฆ๋ ๋ชจ๋ธ์ฌ์ด์ฆ๊ฐ ๋๋๋ก ํ๋ค
- ๋ชจ๋ธ์ฌ์ด์ฆ๋ณด๋ค ์์ผ๋ฉด ํ ํฐ์ ์ฃผ๊ฒ ๋๋ค ๊ทธ ๋ค์๋ ๋น ํ ํฐ (๋ชจ๋ธ ์ฌ์ด์ฆ ๋ง์ถฐ์ฃผ๊ธฐ ์ํด)
- ๊ธฐ์กด์ RNN/CNN์ด ๊ฐ์ง ์ฅ์ ์ค ํ๋ (์์น์ ๋ณด๋ฅผ ์ ์ฅ)
- ํ ํฐ์ ํฌ์ง์ ๋ณ๋ก ํน์ ๊ฐ์ ์ฐ์ถ. ๊ทธ๊ฒ์ ์๋ฒ ๋ฉ ๋ ์ธํ/์์ํ์ ๋ํด์ค๋ค
- self attention ๊ธฐ๋ฒ์ด ์ฌ์ฉ๋๋ค
- ์ด 3๊ฐ๊ฐ ์ฌ์ฉ(multi-head attention)
- attention์ ์ข
๋ฅ
- ์ธํ์ ์ด๋์ ๋ฐ๋์ง์ ๋ฐ๋ผ source-target attention / self attention์ผ๋ก ๋๋๋ค
- source target attention
- query, key, value์ ์ธํ์ด ์์ ๋ ์ฟผ๋ฆฌ๋ ํ๊ฒ์์ ๋ฐ๊ณ ํค, ๋ฐธ๋ฅ๋ ์์ค์์ ๋ฐ์
- ์ฟผ๋ฆฌ๋ฅผ ๋ค๋ฅธ ํน์ ํ ๊ณณ์์ ๋ฐ์์ด
- Q๋ ํน์ ๋ฒกํฐ
- K, V๋ ๋์ผํ ๋ฒกํฐ
- self attention
- ์ฟผ๋ฆฌ, ํค, ๋ฐธ๋ฅ ๋ชจ๋ ์์ค์์ ๋ฐ์
- ํน์ ๋ฒกํฐ๊ฐ ์์ ๋ ์ด๊ฒ์ ํค์๋ ์ฃผ๊ณ ๋ฐธ๋ฅ์๋ ์ฃผ๊ณ ์ฟผ๋ฆฌ์๋ ์ฃผ๋ ๊ฒ
- ์ฟผ๋ฆฌ ํค ๋ฐธ๋ฅ ๋ชจ๋ ๋์ผํ ๋ฒกํฐ
- ์ฟผ๋ฆฌ ํค ๋ฐธ๋ฅ๋ก ๋๋๋ ์ด์ ?
- ํค ๋ฐธ๋ฅ๋ผ๊ณ ํ๋ ํน๋ณํ ์ญํ ์ ๋๋ ๋ค์์ ํ๋๋ ํค๋ก ๋ณด๋ด๊ณ ํ๋๋ ๋ฐธ๋ฅ๋ก ๋ณด๋ด๊ณ
- ํ๋๋ ์ดํ ์ ๋ฐธ๋ฅ๋ฅผ ๋ฝ์๋ด๋๋ฐ ์ฐ์ด๊ณ ํ๋๋ ํ๋ ๋ ์ด์ด๋ฅผ ๋ํํ๋๋ฐ๋ง ์ฐ๊ฒ ํ์ -> ๊ฒฐ๊ณผ ์ข์๋ค
- ์ด๊ฒ์ด ์ดํ ์ ๋คํธ์ํฌ. ์ดํ ์ ์ฐ๋ฉด ํค ๋ฐธ๋ฅ ๋๋๋ ๊ฒ์ด ๋ณดํธํ
- source target attention
- ์ฐ์ฐ ๋ฐฉ์์ ๋ฐ๋ผ
- additive attention
- ์ฟผ๋ฆฌ ํค๋ฅผ FFN์ ํต๊ณผ. ์ด๊ฒ์ ๋ฐธ๋ฅ์ ํต๊ณผ
- FFN(์ฟผ๋ฆฌ;ํค) * ๋ฐธ๋ฅ
- FFN(์ฟผ๋ฆฌ;ํค) : attention weight
- attention weight๊ฐ ํ๋ฒ์ ๋์ค๋ ๊ฒ์ด ์๋๊ณ , ๋ฐธ๋ฅ๊น์ง ๊ณ์ฐํ ํ์์ผ ์ดํ ์ ์จ์ดํธ๊ฐ ์ฐ์ถ
- attention weight ๋์ค๊ณ -> value ๊ณฑํ๊ณ -> attention weight ๋์จ๋ค (์์ฐจ์ )
- dot-product attention
- a(์ฟผ๋ฆฌ*(ํค)T)*๋ฐธ๋ฅ
- a : attention weight
- ํ๋ฒ์ ๋งคํธ๋ฆฌ์ค ๊ณฑ์ ํตํด 1์ฐจ์ ์ผ๋ก attention weight๊ฐ ๋์ค๊ฒ ๋๋ค
- transformer์์ ์ฌ์ฉํ ์ดํ ์ ์ dot-product attention
- a(์ฟผ๋ฆฌ*(ํค)T)*๋ฐธ๋ฅ
- additive attention
- multi-head attention
- ์ฟผ๋ฆฌ ํค ๋ฐธ๋ฅ๋ฅผ ๋๋ ๊ฒ์ linear์ ํต๊ณผ
- scale dot-product attention
- scaling ๊ฐ์ ๊ณฑํด์ค ๊ฒ
- softmax๋ฅผ ์ทจํ ๋ค์์ v๋ฅผ ๊ณฑํ ๊ฒ
- ํ๋ ์ด์
- Q, K๊ฐ dot-productํ๋ฉด ๊ฐ ๋๋ฌด ์ปค์ง๋๊น sqrt(dim)์ ํด์ ์๊ฒํ๋ค
- Q*K = N(0, dim)์ด ๋๋ค
- Q, K๊ฐ dot-productํ๋ฉด ๊ฐ ๋๋ฌด ์ปค์ง๋๊น sqrt(dim)์ ํด์ ์๊ฒํ๋ค
- ๋์จ ๊ฐ๋ค์ concatํ๊ฒ ๋๊ณ ๋ง์ง๋ง์ linearํจ์๋ฅผ ํตํด ๊ฐ์ ๋ด๋ณด๋ธ๋ค
- ์ธํ์ ์ด๋์ ๋ฐ๋์ง์ ๋ฐ๋ผ source-target attention / self attention์ผ๋ก ๋๋๋ค
- ์ผ๋ฐ์ ์ธ denseํ ๊ฒ์ ์ฌ์ฉ
- output embedding ์ ํ ์นธ์ ๊ผญ ๋์์ค๋ค
- masked : output probability๋ฅผ ๋ฝ์๋ด๊ธฐ ์ํด์ ํ์ฌ์ ํ ํฐ๋ค๊ณผ ์ด์ ์ ํ ํฐ๋ค์ด ์๋ค๊ณ ํ ๋ ๊ทธ ๋ค์ ํ ํฐ๋ค์ ๋ง์คํนํ๋ ๊ณผ์
- ์์ ๋จ์ด๋ค์ ๋ํด์๋ง attention
- encoder์ output๊ณผ ํจ๊ป ์ฌ์ฉ
- feed forward networks : position wise (1D conv)
- 0๋ณด๋ค ์์ผ๋ฉด 0 ๋ด๋ณด๋ด๊ณ ๊ทธ๋ณด๋ค ํฌ๋ฉด weight๋ฅผ ๋ด๋ณด๋ธ๋ค (ReLU์ ๋งค์ฐ ๋น์ท)
- linear mapping
- ์ด์ ๊น์ง๋ ์์ํ ์๋ฒ ๋ฉ ๋ ๊ฐ๋ค์ด ๋ค ๋ชจ๋ธ ์ฌ์ด์ฆ (512์ฐจ์๊ฐ์ด) ๋ก ๋์๋๋ฐ
- linear์ ๋ค์ด๊ฐ๋ฉด์ ์ ์ฒด ์๋ ๋ฒกํฐ์ ๊ฐ๊ฐ์ ์์ธก์ด ํ์ํ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ ์ฌ์ด์ฆ(512)์ ์๋ ์ฌ์ด์ฆ(๋ง๊ฐ?)๋ฅผ ๋ง์ถฐ์ฃผ๋ ๊ณผ์ ์ด ํ์
- ์ผ๋ฐ fully connect layer
- ์ ๋ณด ์์ถ
- ์ฐจ์ ์ถ์ -> ์๋ ํฅ์
- ๋จ์ด ๊ธธ์ด ๋ฒกํฐ๋งํผ์ softmax๋ฅผ ๋ฝ์๋
- ๊ฐ ๋จ์ด๋ s.m ํ๋ฅ ๊ฐ์ผ๋ก
- 0๊ณผ 1์ฌ์ด์ ๊ฐ์ ๋ด๋ณด๋ธ๋ค
- ๋ ์ด๋ธ๊ณผ ๋น๊ตํ๋๋ฐ output-label์ ๊ฐ์ ์ค์ด๋ ๊ฒ์ด ํธ๋ ์ด๋์ ๋ชฉํ
encoder๊ณผ decoder๋ ์ฌ๋ฌ๊ฐ๋ฅผ ์ฌ์ฉํด์ ๋ค์ค์ผ๋ก ์ฐ์ฐํ๋ค
- ์ํ์ค ๋ชจ๋ธ : ํธ๋ ์ด๋๊ณผ ํ
์คํธ ๋ฐฉ๋ฒ์ด ์ฐจ์ด๊ฐ ์๋ค
- ํธ๋ ์ด๋
- ์คํํธ ํ ํฐ์ ๋ฃ์ด์ ๋์ฝ๋ ํตํด์ ์์ํ1์ด ๋์ค๊ณ ์ด์ ์ ์ค๋นํ ๋ผ๋ฒจ1์ด ๋์ฝ๋๋ก ๋ค์ด๊ฐ์ ์์ํ2๊ฐ ๋์ค๊ฒ ๋๋ค
- ๋ผ๋ฒจ2๊ฐ ๋ค์ด๊ฐ์ ์์ํ3์ด ๋์จ๋ค
- ์์ํ1๊ณผ ๋ผ๋ฒจ1, ์์ํ2์ ๋ผ๋ฒจ2์ ์ฐจ์ด๋ฅผ ์ค์ด๋ ๊ฒ
- ํ
์คํ
- ์คํํธ ํ ํฐ์ด ๋ค์ด๊ฐ์ ์์ํ์ด ๋์ค๋ฉด ๊ทธ ์์ํ์ ๊ทธ ๋ค์์ ๋์ฝ๋์ ์ธํ์ผ๋ก ๋ฃ๊ฒ ๋๋ค
- ํธ๋ ์ด๋
- ResNet์ ์ ์ฉ๋ ๊ธฐ๋ฒ
- ์๋์ x์ layer๋ฅผ ๊ฑฐ์น f(x)๋ฅผ ๋ํจ
- gradient vanishing ๋ฌธ์ ํด๊ฒฐ (๋ ์ด์ด๋ฅผ ๋ ๊น๊ฒ ์์ ์ ์๊ฒ ํจ)