Transformer - BD-SEARCH/MLtutorial GitHub Wiki

Transformer?

  • 'Attention is all you need'์—์„œ ๋‚˜์˜จ ๋ชจ๋ธ
  • ๊ธฐ์กด์˜ seq2seq ๊ตฌ์กฐ์ธ ์ธ์ฝ”๋”-๋””์ฝ”๋”๋ฅผ ๋”ฐ๋ฅด๋ฉด์„œ๋„ attention๋งŒ์œผ๋กœ ๊ตฌํ˜„ํ•œ ๋ชจ๋ธ
  • ๊ธฐ์กด์˜ RNN, CNN ๊ตฌ์กฐ์˜ ํ•œ๊ณ„๋ฅผ ํƒˆํ”ผํ•˜๊ธฐ ์œ„ํ•จ
    • CNN
      • m*m์˜ ํ•„ํ„ฐ๋ฅผ ์Šฌ๋ผ์ด๋”ฉ
      • local์ ์ธ ๊ฒƒ์€ ์ž˜ ๋ฐ˜์˜ํ•˜์ง€๋งŒ ๋งจ ๋์ชฝ์— ๋Œ€ํ•œ ์˜๋ฏธ๋ฅผ ํŒŒ์•…ํ•˜๊ธฐ๊ฐ€ ์–ด๋ ค์›€
      • ht = f(xt, xt-k) + ... + f(xt, xt) + ... + f(xt, xt+k)
    • RNN
      • ์—ฐ์‡„์ ์ธ ์—ฐ์‚ฐ
      • ht = f(xt, f(xt-1, ... f(x3, f(x2, f(x1)))))
      • ์ด์ „์˜ ๋ฌธ์žฅ์˜ ํŠน์ง•์„ ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ์Œ
      • ๋ฌธ์žฅ์ด ๊ธธ์–ด์งˆ ์ˆ˜๋ก ๋์—๊ฐ€ ๋ฐ˜์˜๋˜๊ธฐ ์–ด๋ ค์›€
        • momentum์ด ์ƒ๊ฒจ์„œ dependency๋•Œ๋ฌธ์— ๋์„ ๋ฐ˜์˜ํ•  ์ˆ˜ ์—†์Œ
    • RN(Relation Network)
      • ht = f(xt, x1) + ... + f(xt, xt-1) + f(xt, xt+1) + ... + f(xt, xT)
      • xt๋ฅผ ์ œ์™ธํ•˜๊ณ  ๋‚˜๋จธ์ง€๋ฅผ ์ „์ฒด ์—ฐ์‚ฐ
      • ๊ธธ์ด๊ฐ€ ๊ธธ์–ด์ง„๋‹ค๊ณ  ์–‘ ๋์„ ๋‹ค ์ž˜ ๋ฐ˜์˜ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ ์—ฐ์‚ฐ์„ ๊ต‰์žฅํžˆ ๋งŽ์ด ํ•ด์•ผ ํ•จ

๊ตฌ์กฐ

image

  • ์ธํ’‹(๋ฒˆ์—ญํ•  ์–ธ์–ด) / ์•„์›ƒํ’‹(๋ฒˆ์—ญ๋œ ์–ธ์–ด) ๋ผ๋ฒจ์„ ์ฃผ๋ฉด ๊ทธ๊ฒƒ์„ ๋ฐ”ํƒ•์œผ๋กœ output probabilities๋ฅผ ๋‚ด๋†“๋Š”๋‹ค
  • ์ธ์ฝ”๋”๋ฅผ ๊ฑฐ์ณ์„œ ํŠน์ • ๊ฐ’์ด ๋‚˜์˜ค๊ฒŒ ๋˜๋ฉด ๋””์ฝ”๋”์—์„œ right shifted๋œ ๋ผ๋ฒจ์„ ํ†ต๊ณผํ•ด ํ•จ๊ผ๋œ๋‹ค
    • right shifted : soyoung is working at holiday (๊ฐ ๋‹จ์–ด๋Š” int)

1) encoder

input

  • (soyoung is working at holiday) ํ† ํฐ๋“ค

input embedding

  • image
  • ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๊ฐ’์„ ์ทจํ•ด์„œ ๋„ฃ์–ด์ค€๋‹ค
  • ํ† ํฐ๋ณ„๋กœ ๋Š์–ด์„œ ์ „์ฒด ์‚ฌ์ด์ฆˆ๋Š” ๋ชจ๋ธ์‚ฌ์ด์ฆˆ๊ฐ€ ๋˜๋„๋ก ํ•œ๋‹ค
  • ๋ชจ๋ธ์‚ฌ์ด์ฆˆ๋ณด๋‹ค ์ž‘์œผ๋ฉด ํ† ํฐ์„ ์ฃผ๊ฒŒ ๋œ๋‹ค ๊ทธ ๋’ค์—๋Š” ๋นˆ ํ† ํฐ (๋ชจ๋ธ ์‚ฌ์ด์ฆˆ ๋งž์ถฐ์ฃผ๊ธฐ ์œ„ํ•ด)

position encoding

  • ๊ธฐ์กด์˜ RNN/CNN์ด ๊ฐ€์ง„ ์žฅ์  ์ค‘ ํ•˜๋‚˜ (์œ„์น˜์ •๋ณด๋ฅผ ์ €์žฅ)
  • ํ† ํฐ์˜ ํฌ์ง€์…˜ ๋ณ„๋กœ ํŠน์ • ๊ฐ’์„ ์‚ฐ์ถœ. ๊ทธ๊ฒƒ์„ ์ž„๋ฒ ๋”ฉ ๋œ ์ธํ’‹/์•„์›ƒํ’‹์— ๋”ํ•ด์ค€๋‹ค

multi-head attention

  • image
  • self attention ๊ธฐ๋ฒ•์ด ์‚ฌ์šฉ๋œ๋‹ค
  • ์ด 3๊ฐœ๊ฐ€ ์‚ฌ์šฉ(multi-head attention)
  • attention์˜ ์ข…๋ฅ˜
    • ์ธํ’‹์„ ์–ด๋””์„œ ๋ฐ›๋Š”์ง€์— ๋”ฐ๋ผ source-target attention / self attention์œผ๋กœ ๋‚˜๋‰œ๋‹ค
      • source target attention
        • query, key, value์˜ ์ธํ’‹์ด ์žˆ์„ ๋•Œ ์ฟผ๋ฆฌ๋Š” ํƒ€๊ฒŸ์—์„œ ๋ฐ›๊ณ  ํ‚ค, ๋ฐธ๋ฅ˜๋Š” ์†Œ์Šค์—์„œ ๋ฐ›์Œ
        • ์ฟผ๋ฆฌ๋ฅผ ๋‹ค๋ฅธ ํŠน์ •ํ•œ ๊ณณ์—์„œ ๋ฐ›์•„์˜ด
        • Q๋Š” ํŠน์ • ๋ฒกํ„ฐ
        • K, V๋Š” ๋™์ผํ•œ ๋ฒกํ„ฐ
      • self attention
        • ์ฟผ๋ฆฌ, ํ‚ค, ๋ฐธ๋ฅ˜ ๋ชจ๋‘ ์†Œ์Šค์—์„œ ๋ฐ›์Œ
        • ํŠน์ • ๋ฒกํ„ฐ๊ฐ€ ์žˆ์„ ๋•Œ ์ด๊ฒƒ์„ ํ‚ค์—๋„ ์ฃผ๊ณ  ๋ฐธ๋ฅ˜์—๋„ ์ฃผ๊ณ  ์ฟผ๋ฆฌ์—๋„ ์ฃผ๋Š” ๊ฒƒ
        • ์ฟผ๋ฆฌ ํ‚ค ๋ฐธ๋ฅ˜ ๋ชจ๋‘ ๋™์ผํ•œ ๋ฒกํ„ฐ
      • ์ฟผ๋ฆฌ ํ‚ค ๋ฐธ๋ฅ˜๋กœ ๋‚˜๋ˆ„๋Š” ์ด์œ ?
        • ํ‚ค ๋ฐธ๋ฅ˜๋ผ๊ณ  ํ•˜๋Š” ํŠน๋ณ„ํ•œ ์—ญํ• ์„ ๋‚˜๋ˆˆ ๋‹ค์Œ์— ํ•˜๋‚˜๋Š” ํ‚ค๋กœ ๋ณด๋‚ด๊ณ  ํ•˜๋‚˜๋Š” ๋ฐธ๋ฅ˜๋กœ ๋ณด๋‚ด๊ณ 
        • ํ•˜๋‚˜๋Š” ์–ดํƒ ์…˜ ๋ฐธ๋ฅ˜๋ฅผ ๋ฝ‘์•„๋‚ด๋Š”๋ฐ ์“ฐ์ด๊ณ  ํ•˜๋‚˜๋Š” ํžˆ๋“ ๋ ˆ์ด์–ด๋ฅผ ๋Œ€ํ‘œํ•˜๋Š”๋ฐ๋งŒ ์“ฐ๊ฒŒ ํ•˜์ž -> ๊ฒฐ๊ณผ ์ข‹์•˜๋‹ค
        • ์ด๊ฒƒ์ด ์–ดํ…์…˜ ๋„คํŠธ์›Œํฌ. ์–ดํ…์…˜์“ฐ๋ฉด ํ‚ค ๋ฐธ๋ฅ˜ ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ด ๋ณดํŽธํ™”
    • ์—ฐ์‚ฐ ๋ฐฉ์‹์— ๋”ฐ๋ผ
      • additive attention
        • ์ฟผ๋ฆฌ ํ‚ค๋ฅผ FFN์— ํ†ต๊ณผ. ์ด๊ฒƒ์„ ๋ฐธ๋ฅ˜์— ํ†ต๊ณผ
        • FFN(์ฟผ๋ฆฌ;ํ‚ค) * ๋ฐธ๋ฅ˜
          • FFN(์ฟผ๋ฆฌ;ํ‚ค) : attention weight
        • attention weight๊ฐ€ ํ•œ๋ฒˆ์— ๋‚˜์˜ค๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๊ณ , ๋ฐธ๋ฅ˜๊นŒ์ง€ ๊ณ„์‚ฐํ•œ ํ›„์—์•ผ ์–ดํ…์…˜ ์›จ์ดํŠธ๊ฐ€ ์‚ฐ์ถœ
        • attention weight ๋‚˜์˜ค๊ณ  -> value ๊ณฑํ•˜๊ณ  -> attention weight ๋‚˜์˜จ๋‹ค (์ˆœ์ฐจ์ )
      • dot-product attention
        • a(์ฟผ๋ฆฌ*(ํ‚ค)T)*๋ฐธ๋ฅ˜
          • a : attention weight
        • ํ•œ๋ฒˆ์— ๋งคํŠธ๋ฆฌ์Šค ๊ณฑ์„ ํ†ตํ•ด 1์ฐจ์ ์œผ๋กœ attention weight๊ฐ€ ๋‚˜์˜ค๊ฒŒ ๋œ๋‹ค
        • transformer์—์„œ ์‚ฌ์šฉํ•  ์–ดํ…์…˜์€ dot-product attention
    • multi-head attention
      • ์ฟผ๋ฆฌ ํ‚ค ๋ฐธ๋ฅ˜๋ฅผ ๋‚˜๋ˆˆ ๊ฒƒ์„ linear์„ ํ†ต๊ณผ
      • scale dot-product attention
        • scaling ๊ฐ’์„ ๊ณฑํ•ด์ค€ ๊ฒƒ
        • softmax๋ฅผ ์ทจํ•œ ๋‹ค์Œ์— v๋ฅผ ๊ณฑํ•œ ๊ฒƒ
        • ํ•˜๋Š” ์ด์œ 
          • Q, K๊ฐ€ dot-productํ•˜๋ฉด ๊ฐ’ ๋„ˆ๋ฌด ์ปค์ง€๋‹ˆ๊นŒ sqrt(dim)์„ ํ•ด์„œ ์ž‘๊ฒŒํ•œ๋‹ค
            • Q*K = N(0, dim)์ด ๋œ๋‹ค
      • ๋‚˜์˜จ ๊ฐ’๋“ค์„ concatํ•˜๊ฒŒ ๋˜๊ณ  ๋งˆ์ง€๋ง‰์— linearํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ๊ฐ’์„ ๋‚ด๋ณด๋‚ธ๋‹ค

add & norm

feed forward networks

  • ์ผ๋ฐ˜์ ์ธ denseํ•œ ๊ฒƒ์„ ์‚ฌ์šฉ

add & norm

2) decoder

output(right shifted)

output embedding

  • output embedding ์€ ํ•œ ์นธ์„ ๊ผญ ๋„์›Œ์ค€๋‹ค

positional encoding

masked multi-head attention

  • masked : output probability๋ฅผ ๋ฝ‘์•„๋‚ด๊ธฐ ์œ„ํ•ด์„œ ํ˜„์žฌ์˜ ํ† ํฐ๋“ค๊ณผ ์ด์ „์˜ ํ† ํฐ๋“ค์ด ์žˆ๋‹ค๊ณ  ํ•  ๋•Œ ๊ทธ ๋‹ค์Œ ํ† ํฐ๋“ค์„ ๋งˆ์Šคํ‚นํ•˜๋Š” ๊ณผ์ •
  • ์•ž์˜ ๋‹จ์–ด๋“ค์— ๋Œ€ํ•ด์„œ๋งŒ attention

add & norms

multi-head attention

  • encoder์˜ output๊ณผ ํ•จ๊ป˜ ์‚ฌ์šฉ

add & norms

feed forward networks

  • image
  • feed forward networks : position wise (1D conv)
  • 0๋ณด๋‹ค ์ž‘์œผ๋ฉด 0 ๋‚ด๋ณด๋‚ด๊ณ  ๊ทธ๋ณด๋‹ค ํฌ๋ฉด weight๋ฅผ ๋‚ด๋ณด๋‚ธ๋‹ค (ReLU์™€ ๋งค์šฐ ๋น„์Šท)

add & norm

3) output

linear

  • linear mapping
  • ์ด์ „๊นŒ์ง€๋Š” ์•„์›ƒํ’‹ ์ž„๋ฒ ๋”ฉ ๋œ ๊ฐ’๋“ค์ด ๋‹ค ๋ชจ๋ธ ์‚ฌ์ด์ฆˆ (512์ฐจ์›๊ฐ™์ด) ๋กœ ๋‚˜์™”๋Š”๋ฐ
  • linear์„ ๋“ค์–ด๊ฐ€๋ฉด์„œ ์ „์ฒด ์›Œ๋“œ ๋ฒกํ„ฐ์˜ ๊ฐ๊ฐ์˜ ์˜ˆ์ธก์ด ํ•„์š”ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ชจ๋ธ ์‚ฌ์ด์ฆˆ(512)์™€ ์›Œ๋“œ ์‚ฌ์ด์ฆˆ(๋งŒ๊ฐœ?)๋ฅผ ๋งž์ถฐ์ฃผ๋Š” ๊ณผ์ •์ด ํ•„์š”
  • ์ผ๋ฐ˜ fully connect layer
    • ์ •๋ณด ์••์ถ•
    • ์ฐจ์› ์ถ•์†Œ -> ์†๋„ ํ–ฅ์ƒ

softmax

  • ๋‹จ์–ด ๊ธธ์ด ๋ฒกํ„ฐ๋งŒํผ์˜ softmax๋ฅผ ๋ฝ‘์•„๋ƒ„

output probabilities

  • ๊ฐ ๋‹จ์–ด๋Š” s.m ํ™•๋ฅ ๊ฐ’์œผ๋กœ
  • 0๊ณผ 1์‚ฌ์ด์˜ ๊ฐ’์„ ๋‚ด๋ณด๋‚ธ๋‹ค
  • ๋ ˆ์ด๋ธ”๊ณผ ๋น„๊ตํ•˜๋Š”๋ฐ output-label์˜ ๊ฐ’์„ ์ค„์ด๋Š” ๊ฒƒ์ด ํŠธ๋ ˆ์ด๋‹์˜ ๋ชฉํ‘œ

encoder๊ณผ decoder๋Š” ์—ฌ๋Ÿฌ๊ฐœ๋ฅผ ์‚ฌ์šฉํ•ด์„œ ๋‹ค์ค‘์œผ๋กœ ์—ฐ์‚ฐํ•œ๋‹ค

image

  • ์‹œํ€€์Šค ๋ชจ๋ธ : ํŠธ๋ ˆ์ด๋‹๊ณผ ํ…Œ์ŠคํŠธ ๋ฐฉ๋ฒ•์ด ์ฐจ์ด๊ฐ€ ์žˆ๋‹ค
    • ํŠธ๋ ˆ์ด๋‹
      • ์Šคํƒ€ํŠธ ํ† ํฐ์„ ๋„ฃ์–ด์„œ ๋””์ฝ”๋” ํ†ตํ•ด์„œ ์•„์›ƒํ’‹1์ด ๋‚˜์˜ค๊ณ  ์ด์ „์— ์ค€๋น„ํ•œ ๋ผ๋ฒจ1์ด ๋””์ฝ”๋”๋กœ ๋“ค์–ด๊ฐ€์„œ ์•„์›ƒํ’‹2๊ฐ€ ๋‚˜์˜ค๊ฒŒ ๋œ๋‹ค
      • ๋ผ๋ฒจ2๊ฐ€ ๋“ค์–ด๊ฐ€์„œ ์•„์›ƒํ’‹3์ด ๋‚˜์˜จ๋‹ค
      • ์•„์›ƒํ’‹1๊ณผ ๋ผ๋ฒจ1, ์•„์›ƒํ’‹2์™€ ๋ผ๋ฒจ2์˜ ์ฐจ์ด๋ฅผ ์ค„์ด๋Š” ๊ฒƒ
    • ํ…Œ์ŠคํŒ…
      • ์Šคํƒ€ํŠธ ํ† ํฐ์ด ๋“ค์–ด๊ฐ€์„œ ์•„์›ƒํ’‹์ด ๋‚˜์˜ค๋ฉด ๊ทธ ์•„์›ƒํ’‹์„ ๊ทธ ๋‹ค์Œ์˜ ๋””์ฝ”๋”์˜ ์ธํ’‹์œผ๋กœ ๋„ฃ๊ฒŒ ๋œ๋‹ค

4) residual connection

  • ResNet์— ์ ์šฉ๋œ ๊ธฐ๋ฒ•
  • ์›๋ž˜์˜ x์™€ layer๋ฅผ ๊ฑฐ์นœ f(x)๋ฅผ ๋”ํ•จ
  • gradient vanishing ๋ฌธ์ œ ํ•ด๊ฒฐ (๋ ˆ์ด์–ด๋ฅผ ๋” ๊นŠ๊ฒŒ ์Œ“์„ ์ˆ˜ ์žˆ๊ฒŒ ํ•จ)

reference

โš ๏ธ **GitHub.com Fallback** โš ๏ธ