attention mechanism - beyondnlp/nlp GitHub Wiki

์–ดํ…์…˜(attention)์ด๋ž€

vanilla rnn์—์„œ Vanishing Gradient Problem๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋‚˜์˜จ lstm์—์„œ๋„ ๋ฌธ์ž์—ด์ด ๊ธธ์–ด์ง€๋ฉด ํšจ๊ณผ์ ์œผ๋กœ ์ •๋ณด๋ฅผ ์••์ถ•ํ•˜์ง€ ๋ชปํ•˜๋Š” ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•˜์—ฌ lstm์„ ์‚ฌ์šฉํ•œ seq2seq๋ชจ๋ธ์—์„œ๋„ ๋ฒˆ์—ญ์˜ ํ’ˆ์งˆ์ด ๋–จ์–ด์ง€๋Š” ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•œ๋‹ค. ๋ชจ๋“  ์ •๋ณด๋ฅผ ๋‹ค ์‚ฌ์šฉํ• ๋ ค๋Š”๋ฐ ๋ฌธ์ œ๊ฐ€ ์žˆ๋‹ค๊ณ ๋„ ๋ณผ์ˆ˜ ์žˆ๋‹ค. ์ด๋Ÿฐ ๋ฌธ์ œ๋ฅผ ํšจ๊ณผ์ ์œผ๋กœ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋‚˜์˜จ ์•Œ๊ณ ๋ฆฌ์ฆ˜์ด attention mechanism์ด๋‹ค.( attention๋ง๊ณ ๋„ ์œ„์—์„œ ์–ธ๊ธ‰ํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด์„œ bidirectional lstm์œผ๋กœ๋„ ์–ด๋А ์ •๋„ ํ•ด๊ฒฐ์ด ๋œ๋‹ค. )

์ด๋ฅผ ์ข€ ๋” ์ดํ•ด๊ฐ€ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•˜๋ฉด ๋‚˜๋Š” ๋งฅ์ฃผ๊ฐ€ ์ข‹๋‹ค๋Š” ์˜๋ฏธ์˜ ๋…์ผ์–ด ๋ฌธ์žฅ(Ich mochte ein bier)๊ณผ ์˜์–ด ๋ฌธ์žฅ(Iโ€™d like a beer)์„ seq2seq๋ฅผ ํ†ตํ•ด ๋ฒˆ์—ญํ•œ๋‹ค๊ณ  ํ•  ๋•Œ encoder2decoder ์‚ฌ๋žŒ์€ ์ง๊ด€์ ์œผ๋กœ beer์€ bier์—๋งŒ ์˜ํ–ฅ์„ ๋ฐ›๋Š” ๋‹ค๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๊ณ  ์ด๋ฅผ ์•Œ๊ณ ๋ฆฌ์ฆ˜ํ™” ํ•œ ๊ฒƒ์ด๋‹ค. beer๋ฅผ ์˜ˆ์ธกํ•  ๋•Œ bier์ด์™ธ์— ๊ฒƒ์€ ๋ณ„๋‹ค๋ฅธ ์‹ ๊ฒฝ์„ ์“ธ ํ•„์š”๊ฐ€ ์—†๊ณ  ์˜คํžˆ๋ ค ์„ฑ๋Šฅ์„ ์ €ํ•˜์‹œํ‚ค๋Š” ์š”์ธ์ด ๋˜๊ธฐ ๋•Œ๋ฌธ์— bier์— ์ง‘์ค‘ํ•˜๊ฒ ๋‹ค๋Š” ์˜๋ฏธ์ด๋‹ค. ๊ทธ๋Ÿผ ์–ด๋–ค ์ •๋ณด์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ beer๋ฅผ ์˜ˆ์ธกํ•  ๋•Œ๋Š” bier์— ์˜ํ–ฅ๋ฐ›๋Š” ๋‹ค๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ์„๊นŒ? "encoder์—์„œ bier๋ฅผ ์ž…๋ ฅ์œผ๋กœ ํ•ด์„œ ๋งŒ๋“  ์ถœ๋ ฅ ๋ฒกํ„ฐ์™€ decoder์—์„œ beer๋ฅผ ๋งŒ๋“ค ๋•Œ ์‚ฌ์šฉํ•˜๋Š” ๋ฒกํ„ฐ๊ฐ€ ์„œ๋กœ ์œ ์‚ฌํ•  ๊ฒƒ์ด๋‹ค"๋ผ๋Š” ๊ฐ€์ •์—์„œ ์ถœ๋ฐœํ•œ๋‹ค ( ์ด ๋ถ€๋ถ„์„ ๋ณด๋‹ค ๋ณด๋‹ˆ SMT์—์„œ word alignment์™€ ์œ ์‚ฌํ•œ ๊ฐœ๋…์ด ์•„๋‹Œ๊ฐ€ ํ•˜๋Š” ์ƒ๊ฐ์ด ๋“ค์—ˆ๋‹ค )

how to implement attention

๊ทธ๋Ÿฌ๋ฉด ์–ด๋–ป๊ฒŒ ์ด๋ฅผ ๊ตฌํ˜„ํ•  ๊ฒƒ์ธ๊ฐ€ attention attention

  • H(t)๋Š” ๋””์ฝ”๋”์˜ ๋ฒกํ„ฐ์ด๊ณ  H(s)๋Š” ์ธ์ฝ”๋”์˜ ๋ฒกํ„ฐ์ด๋‹ค ์œ„ ๊ณต์‹์ฒ˜๋Ÿผ
  1. Attention Weights๋Š” H(t)์™€ ๋ชจ๋“  H(s)๋ฅผ ๋‚ด์ ํ•˜์—ฌ ๋‚˜์˜จ ๋ฒกํ„ฐ๋ฅผ softmax๋ฅผ ์ทจํ•ด ํ™•๋ฅ ์„ ๊ตฌํ•œ๋‹ค.
  2. C(t)=Context(t)๋Š” ๋ชจ๋“  H(s)์™€ Attention Weights๋ฅผ ๊ณฑํ•ด์„œ ๋”ํ•œ๋‹ค. ( ์Šค์นผ๋ผ ๊ฐ’์ด ๋‚˜์˜ฌ ๋“ฏ )
  3. attention(H(t))์€ Context(t)์™€ input(H(t))์„ concatenationํ•˜๊ณ  tanh๋ฅผ ํ†ต๊ณผํ•œ ๊ฐ’์ด๋‹ค. ( ๋ฏธ๋ถ„ ๊ฐ€๋Šฅํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•™์Šต์ด ๊ฐ€๋Šฅํ•  ๋“ฏ )
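The three steps above can be written out as a minimal NumPy sketch. This is my own illustration of Luong-style dot attention, not code from the original references; the function name luong_attention and the projection W_c are just placeholders for the learned parameters:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def luong_attention(h_t, H_s, W_c):
    """One decoder step of Luong-style (dot) attention.

    h_t : decoder hidden state, shape (d,)
    H_s : encoder hidden states, shape (T, d)
    W_c : learned projection, shape (d, 2d)
    """
    scores = H_s @ h_t                       # 1. dot product of H(t) with every H(s)
    attention_weights = softmax(scores)      # 1. softmax over the scores -> probabilities
    context = attention_weights @ H_s        # 2. weighted sum of encoder states (a vector)
    concat = np.concatenate([context, h_t])  # 3. concatenate Context(t) and H(t)
    attention_h_t = np.tanh(W_c @ concat)    # 3. tanh of a learned projection
    return attention_h_t, attention_weights

# toy example
d, T = 4, 3
rng = np.random.default_rng(0)
h_t = rng.normal(size=d)
H_s = rng.normal(size=(T, d))
W_c = rng.normal(size=(d, 2 * d))
out, weights = luong_attention(h_t, H_s, W_c)
print(weights, weights.sum())  # the attention weights sum to 1
```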

์—ฌ๊ธฐ์„œ stanford nlp์˜ ์ž๋ฃŒ๋ฅผ ์ฐธ๊ณ ํ•˜๋ฉด calcuate attention weight keras code๋ฅผ ๋ณด๋ฉด ์•„๋ž˜์™€ ๊ฐ™์ด ๊ตฌํ˜„๋ผ ์žˆ๋‹ค.

```python
from keras.layers import Input, Dense, Multiply
from keras.models import Model

input_dim = 32  # example input dimension

inputs = Input(shape=(input_dim,))
# ATTENTION PART STARTS HERE
attention_probs = Dense(input_dim, activation='softmax', name='attention_vec')(inputs)
# element-wise multiplication of the inputs and the attention probabilities
attention_mul = Multiply(name='attention_mul')([inputs, attention_probs])
# ATTENTION PART FINISHES HERE
attention_mul = Dense(64)(attention_mul)
output = Dense(1, activation='sigmoid')(attention_mul)
model = Model(inputs=[inputs], outputs=output)
```

A Dense layer with an input_dim * input_dim weight matrix is created and a softmax activation is applied; this output is named attention_probs. Next, inputs and attention_probs are merged by element-wise multiplication (not a matrix product) to produce attention_mul, which then goes through a fully connected layer producing a 64-dimensional output, and finally a sigmoid layer produces the 1-dimensional output.
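Continuing from the model above, here is a hedged usage sketch; the synthetic data, the training call, and the way the attention vector is read back are my own additions modeled on the typical demo for this kind of toy attention layer, not part of the original post:

```python
import numpy as np
from keras.models import Model

# synthetic data: the label is planted into column 5 only, so a trained model
# should concentrate its attention probabilities on that column
x = np.random.normal(size=(10000, input_dim))
y = (np.random.uniform(size=10000) > 0.5).astype('float32')
x[:, 5] = y

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(x, y, epochs=5, batch_size=64, validation_split=0.1, verbose=0)

# read back the learned attention distribution for some inputs
attention_model = Model(inputs=model.input,
                        outputs=model.get_layer('attention_vec').output)
print(attention_model.predict(x[:32]).mean(axis=0))  # expected to peak at index 5
```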

์–ดํ…์…˜์—์„œ ํ•™์Šต์ด ๋˜๋Š” ๋ถ€๋ถ„

๋งˆ์ง€๋ง‰์œผ๋กœ "A Brief Overview of Attention Mechanism" ์˜ ๋ฏธ๋””์—„ ๋ธ”๋กœ๊ทธ ๊ธ€ ๋งˆ์ง€๋ง‰์— ์•„๋ž˜์™€ ๊ฐ™์€ ๋ฌธ์žฅ์ด ์žˆ๋‹ค. There are many variants in the cutting-edge researches, and they basically differ in the choice of score function and attention function, or of soft attention and hard attention (whether differentiable). But basic concepts are all the same.

์ตœ์ฒจ๋‹จ์˜ ์—ฐ๊ตฌ ๋ถ„์•ผ์—๋Š” ๋งŽ์€ ๋ณ€ํ˜•๋“ค์ด ์žˆ์Šต๋‹ˆ๋‹ค. ๊ธฐ๋ณธ์ ์œผ๋กœ scoreํ•จ์ˆ˜์™€ attention ํ•จ์ˆ˜์˜ ์„ ํƒ์ด ๋‹ค๋ฅด๊ณ  ๋˜๋Š” soft attention๊ณผ hard attention( ๋ฏธ๋ถ„์ด ๊ฐ€๋Šฅํ•œ์ง€ )์ด ๋‹ค๋ฆ…๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ๋ฐ ๊ธฐ๋ณธ์ ์ธ ๊ฐœ๋…์€ ๋ชจ๋‘ ๋™์ผํ•ฉ๋‹ˆ๋‹ค.

  • score( Hi, ^Hi ) is learned as a fully connected network
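For reference, a small sketch of the common score-function variants mentioned in the quote above (the dot, general, and concat forms described by Luong et al.; the parameter names W_a and v_a are illustrative placeholders for learned weights):

```python
import numpy as np

def score_dot(h_t, h_s):
    # dot: score = h_t . h_s
    return h_t @ h_s

def score_general(h_t, h_s, W_a):
    # general: score = h_t^T W_a h_s, with W_a a learned (d x d) matrix
    return h_t @ W_a @ h_s

def score_concat(h_t, h_s, W_a, v_a):
    # concat / additive: score = v_a^T tanh(W_a [h_t; h_s])
    # W_a has shape (k, 2d) and v_a has shape (k,);
    # this is the "fully connected network" form mentioned above
    return v_a @ np.tanh(W_a @ np.concatenate([h_t, h_s]))
```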

attention ๊ณ„์‚ฐ ๋ฐฉ์‹