Attention Network

Attention Model

  • ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ๋ฒกํ„ฐ Sequence ์ค‘์—์„œ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๋ฒกํ„ฐ์— ์ง‘์ค‘ํ•˜๋„๋ก ํ•˜๋Š” ๋ชจ๋ธ
    • State๋ฅผ ๊ณ ๋ คํ•˜์—ฌ ๊ฐ€์žฅ ์ค‘์š”๋„๊ฐ€ ๋†’์€ ๋ฒกํ„ฐ๋ฅผ ์ค‘์‹ฌ์œผ๋กœ ํ•˜๋‚˜์˜ ๋ฒกํ„ฐ๋กœ ์ •๋ฆฌํ•˜๋Š” ๋ชจ๋ธ.
    • Input
      • y_1, y_2, ..., y_n: ์ž…๋ ฅ ๋ฒกํ„ฐ. 1์ฐจ์› ๋ฐ์ดํ„ฐ์ธ Sequence ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ 2์ฐจ์› ๋ฐ์ดํ„ฐ์ธ ์ด๋ฏธ์ง€ ๋“ฑ๋„ ๋ฐ›์„ ์ˆ˜ ์žˆ๋‹ค.
      • c: Context. ํ˜„์žฌ ์ƒํƒœ(๋ฌธ๋งฅ ๋“ฑ)์„ ๋‚˜ํƒ€๋‚ด๋Š” ๋ฒกํ„ฐ.
    • Output
      • z: context์™€ sequence๋ฅผ ๊ณ ๋ คํ•˜์—ฌ, y ๋ฒกํ„ฐ ์ค‘ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๋ฒกํ„ฐ๋ฅผ ์œ„์ฃผ๋กœ summary๋œ ๊ฐ’.
  • Attention Model์˜ Output์€ ์ค‘์š”ํ•œ ๋ฒกํ„ฐ๋ฅผ ์œ„์ฃผ๋กœ summary๋˜๊ธฐ ๋•Œ๋ฌธ์—, Sequence์˜ ์ค‘์š”ํ•œ ๋ถ€๋ถ„์— ์ง‘์ค‘ํ•œ๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค.

Attention Model์˜ ๊ตฌ์กฐ์™€ ๋™์ž‘ ๋ฐฉ๋ฒ•

Attention Model์€ ๊ฐœ๋…์ ์œผ๋กœ ์•„๋ž˜์™€ ๊ฐ™์ด ๋™์ž‘ํ•œ๋‹ค.

  1. Input์œผ๋กœ ๋“ค์–ด์˜จ ๋ฒกํ„ฐ๋“ค์˜ ์ค‘์š”๋„/์œ ์‚ฌ๋„๋ฅผ, ํ˜„์žฌ state๋ฅผ ๊ณ ๋ คํ•˜์—ฌ ๊ตฌํ•œ๋‹ค.
  2. ๊ฐ๊ฐ์˜ ์ค‘์š”๋„๋ฅผ, ์ด ํ•ฉ์ด 1์ด ๋˜๋Š” ์ƒ๋Œ€๊ฐ’์œผ๋กœ ๋ฐ”๊พผ๋‹ค.
  3. ์ƒ๋Œ€๊ฐ’ ์ค‘์š”๋„๋ฅผ ๊ฐ€์ค‘์น˜๋กœ ๋ณด๊ณ , Sequence์— ์žˆ๋Š” ๋ฒกํ„ฐ๋“ค์„ ๊ฐ€์ค‘์น˜ํ•ฉํ•œ๋‹ค.

Attention Model์„ ๊ตฌํ˜„ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ ์—ฌ๋Ÿฌ ๊ฐ€์ง€๊ฐ€ ์žˆ์œผ๋‚˜, ๊ทธ ์ค‘์—์„œ tanh ํ•จ์ˆ˜์™€ ๋‚ด์ ๊ณฑ์„ ์ด์šฉํ•œ Attention์€ ์•„๋ž˜์™€ ๊ฐ™์ด ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค.

Attention Model์˜ ๊ธฐ๋ณธ ๊ตฌ์กฐ

  1. State C์— W_c ํ–‰๋ ฌ๋ฅผ ๋‚ด์ ํ•œ ๊ฐ’๊ณผ, ๊ฐ๊ฐ์˜ Sequence์˜ ๋ฒกํ„ฐ y_i์— W_y ํ–‰๋ ฌ๋ฅผ ๋‚ด์ ํ•œ ๊ฐ’์„ ๋”ํ•œ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ด ๊ฐ’์„ tanh ํ•จ์ˆ˜์— ํ†ต๊ณผํ‚จ ๊ฐ’์„ m_i์ด๋ผ๊ณ  ํ•œ๋‹ค.
  2. m_i ๊ฐ’์„ Softmax ํ•จ์ˆ˜์— ํ†ต๊ณผ์‹œ์ผœ ํ™•๋ฅ ์„ ๊ตฌํ•œ ๊ฐ’์„ s_i๋ผ๊ณ  ํ•œ๋‹ค.
  3. s_i ๊ฐ’๊ณผ y_i ๊ฐ’์„ ๋‚ด์ ํ•œ๋‹ค. ์ด ๊ฐ’์„ ์ „๋ถ€ ํ•ฉ์นœ ๊ฒƒ์ด ์ถœ๋ ฅ๊ฐ’์ด ๋œ๋‹ค.
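
A minimal NumPy sketch of this structure (the projection vector w that turns each m_i into a scalar score before the softmax is an assumption; the description above leaves that step implicit):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

W_c = np.random.randn(16, 8)  # learned in practice; random here for illustration
W_y = np.random.randn(16, 8)
w = np.random.randn(16)       # assumed projection from m_i to a scalar score

ys = np.random.randn(5, 8)    # y_1..y_5
c = np.random.randn(8)        # state/context

m = np.tanh(W_c @ c + ys @ W_y.T)  # 1. m_i = tanh(W_c·c + W_y·y_i), shape (5, 16)
s = softmax(m @ w)                 # 2. s_i = softmax of the scalar scores, shape (5,)
z = s @ ys                         # 3. output: weighted sum of the y_i, shape (8,)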

์ฃผ์˜์‚ฌํ•ญ

  • ์—ฌ๊ธฐ์„œ W_c, W_y ํ–‰๋ ฌ์€ ํ•™์Šต์„ ํ†ตํ•ด ๊ฐ’์„ ์ •ํ•œ๋‹ค.
  • 1์—์„œ ์‚ฌ์šฉํ•œ ๋ฐฉ๋ฒ•์€ y_i ๊ฐ’๊ณผ C ๊ฐ’์„ ํ•˜๋‚˜๋กœ ์„ž๋Š” ์ด์ƒ, ์–ด๋– ํ•œ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•ด๋„ ๋ฌด๋ฐฉํ•˜๋‹ค. (์˜ˆ: y_i์™€ C ๊ฐ’์„ ๋‚ด์ )

์ด์ „ RNN layer์˜ ์ถœ๋ ฅ๊ฐ’์ด ๋ชจ๋‘ ๋‚˜์™€์•ผ Attention Model์˜ ์ถœ๋ ฅ์„ ์ •ํ•  ์ˆ˜ ์žˆ๋‹ค. (์ด์ „ layer์˜ ๋ชจ๋“  Output์„ Input์œผ๋กœ ๋ฐ›๊ธฐ ๋•Œ๋ฌธ)

Encoder/Decoder์™€ Attention Machanism

  • ๊ฐ’์„ ์ž…๋ ฅ๋ฐ›์•„ ๋ฒกํ„ฐ๋กœ ๋งŒ๋“œ๋Š” encoder(์ธ์ฝ”๋”), ์ธ์ฝ”๋”๊ฐ€ ์ถœ๋ ฅํ•œ ๋ฒกํ„ฐ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์›ํ•˜๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•˜๋Š” decoder(๋””์ฝ”๋”)๊ฐ€ ์žˆ๋‹ค๊ณ  ํ•˜์ž.
    • ์ผ๋ฐ˜์ ์ธ ๋ชจ๋ธ: ์ธ์ฝ”๋”์˜ ๋ชจ๋“  ์ถœ๋ ฅ ๋ฒกํ„ฐ๋ฅผ ๊ณจ๊ณ ๋ฃจ ๋ณด๊ณ  ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•œ๋‹ค.
    • Attention Mechanism: ์ธ์ฝ”๋”์˜ ์ถœ๋ ฅ ๋ฒกํ„ฐ ์ค‘ ์ค‘์š”ํ•œ ๋ฒกํ„ฐ์— ์ง‘์ค‘ํ•œ๋‹ค.
    • ์˜ˆ) "Artificial Intelligence"๋ฅผ "์ธ๊ณต ์ง€๋Šฅ"์œผ๋กœ ๋ฒˆ์—ญํ•œ๋‹ค๊ณ  ๊ฐ€์ •ํ•œ๋‹ค. ์ด ๋•Œ ๋ชจ๋ธ์ด '์ง€๋Šฅ'์„ ์˜ˆ์ธกํ•  ๋•Œ 'Intelligence'์— ์ฃผ๋ชฉํ•˜๊ฒŒ ๋œ๋‹ค.
  • Decoding ๊ณผ์ •์—์„œ ๋ถˆํ•„์š”ํ•œ ๋ฒกํ„ฐ๋ฅผ ๋ณด์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์„ฑ๋Šฅ์ด ํ–ฅ์ƒ๋œ๋‹ค.

RNN์„ ํ™œ์šฉํ•œ ๊ธฐ๊ณ„ ๋ฒˆ์—ญ์— ํ™œ์šฉ๋œ Attention Model

  • Inputs to the attention model when using an RNN encoder/decoder
    • y_1, y_2, ..., y_n: usually the sequence of outputs from the previous RNN layer.
    • c: the immediately preceding state of the RNN model that consumes the attention model's output. For example, when computing the RNN model's third output, the hidden state from the second output can be used (see the sketch after this list).
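
A rough sketch of one decoding step under this scheme (pure NumPy; the simple tanh cell and all shapes are hypothetical stand-ins for the actual RNN):

import numpy as np

n, d = 5, 8                              # hypothetical sequence length and hidden size
encoder_outputs = np.random.randn(n, d)  # y_1..y_n from the previous RNN layer
state = np.zeros(d)                      # decoder state after output t-1 (the context c)
W_in = np.random.randn(d, d)             # hypothetical learned weights
W_rec = np.random.randn(d, d)

scores = encoder_outputs @ state                 # 1. score each y_i against c
w = np.exp(scores - scores.max()); w /= w.sum()  # 2. softmax
z = w @ encoder_outputs                          # 3. attention output for step t
state = np.tanh(W_in @ z + W_rec @ state)        # minimal RNN cell update for step t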

Attention Model์˜ Code Example

์ด์ „ layer์˜ ์ถœ๋ ฅ ๊ฐ’์ด 1์ฐจ์›์ธ ๊ฒฝ์šฐ

# ์ด์ „ layer์˜ ์ „์ฒด ๊ฐ’๋“ค์„ Input์œผ๋กœ ๋ฐ›๋Š”๋‹ค
inputs = Input(shape=(input_dims,))

# ๊ฐ๊ฐ์˜ ๊ฐ’์— ๋Œ€ํ•œ ์ค‘์š”๋„๋ฅผ ๊ตฌํ•œ๋‹ค.
attention_probs = Dense(input_dims, activation='softmax', name='attention_probs')(inputs)

# ๊ฐ๊ฐ์˜ ๊ฐ’ Matrix์™€ ์ค‘์š”๋„ Matrix๋ฅผ ํ–‰๋ ฌ๊ณฑํ•œ๋‹ค. ์ฆ‰, ์ž…๋ ฅ ๊ฐ’๋“ค์„ ์ค‘์š”๋„์— ๋”ฐ๋ผ ๊ฐ€์ค‘์น˜ํ•ฉํ•œ๋‹ค.
attention_mul = merge([inputs, attention_probs], output_shape=input_dims, name='attention_mul', mode='mul')
  • Source taken from philipperemy/keras-attention-mechanism, with comments added; the deprecated Keras 1.x merge(..., mode='mul') call is replaced with the equivalent Multiply layer.

  • ์ˆ˜ํ•™์ ์ธ ์˜๋ฏธ์˜ ํ–‰๋ ฌ(Matrix)๋ฅผ Dense Layer๋กœ ๋Œ€์ฒดํ•˜์—ฌ ๊ตฌํ˜„.

์ด์ „ layer์˜ ์ถœ๋ ฅ์ด 2์ฐจ์›์ธ ๊ฒฝ์šฐ

from keras.layers import Dense, Permute, Reshape, Multiply

# inputs.shape = (batch_size, time_steps, input_dim)
input_dim = int(inputs.shape[2])

# Transpose the input matrix
a = Permute((2, 1))(inputs)
a = Reshape((input_dim, TIME_STEPS))(a)  # shape unchanged; makes the layout explicit

# Compute an importance score for each value (per LSTM cell).
# As a matrix, a now has rows: input_dim, columns: sequence index,
# so the softmax normalizes across the sequence for each input dimension.
a = Dense(TIME_STEPS, activation='softmax')(a)

# Transpose back so the dimensions mean the same as in the input matrix
a_probs = Permute((2, 1), name='attention_vec')(a)
# a_probs - rows: sequence index, columns: input_dim

# Weight the input values by their importance scores
output_attention_mul = Multiply(name='attention_mul')([inputs, a_probs])
  • Source taken from philipperemy/keras-attention-mechanism, with comments added; the deprecated merge(..., mode='mul') call is again replaced with Multiply.

  • Input Matrix๋ฅผ Transposeํ•˜๋Š” ์ด์œ 

    • "๊ฐ๊ฐ์˜ Sequence์˜ ๋ฒกํ„ฐ y_i์— W_y Matrix๋ฅผ ๋‚ด์ ํ•œ ๊ฐ’์„ ๋”ํ•œ๋‹ค."๋ฅผ Dense layer๋กœ ๊ตฌํ˜„ํ•˜๊ธฐ ์œ„ํ•ด.
    • ๋‘ Matrix์˜ ์ฐจ์›์˜ ์˜๋ฏธ๋ฅผ ๋™์ผํ•˜๊ฒŒ ๋งž์ถ”๊ธฐ ์œ„ํ•ด
    • y_i๋ฅผ ์ „์ฒด๋กœ ๋ชจ์€ Y Matrix์™€ M_y Matrix๋ฅผ ๋‚ด์ ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•˜๋ฉด ๋œ๋‹ค.
    • Dense layer ํ•˜๋‚˜๋ฅผ, ํ•™์Šตํ•ด์•ผ ํ•  Matrix๋กœ ์ƒ๊ฐํ•˜๋ฉด ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋‹ค.
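
A shape trace may make the transposition clearer (the TIME_STEPS and input_dim values are arbitrary examples):

# Assume TIME_STEPS = 20, input_dim = 32.
# inputs:                    (batch, 20, 32)  rows: sequence index, columns: input_dim
# after Permute((2, 1)):     (batch, 32, 20)  rows: input_dim, columns: sequence index
# after Dense(20, softmax):  (batch, 32, 20)  per input dim, a distribution over the 20 steps
# after Permute((2, 1)):     (batch, 20, 32)  back to the input layout, ready for the
#                                             element-wise multiplication with inputs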

Attention Mechanism์˜ ํ™œ์šฉ

  • Output์ด Sequence ํ˜•ํƒœ๋ผ๋ฉด, Input์ด ์–ด๋–ค ํ˜•ํƒœ๋“  ๊ฐ„์—, Input์˜ ํŠน์ • ๋ถ€๋ถ„์„ ๊ฐ•์กฐํ•ด์„œ ๋ณด๋Š” ํ˜•ํƒœ๋กœ Attention Mechanism์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.
  • Long-Term Dependency๋ฅผ ํ•ด๊ฒฐํ•  ๋•Œ, ํ˜น์€ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์˜ ํ•™์Šต ์ƒํ™ฉ์„ ์ ๊ฒ€ํ•  ๋•Œ ํ™œ์šฉ ๊ฐ€๋Šฅํ•˜๋‹ค.
  • Transformer, BERT ๋“ฑ์—์„  CNN/RNN ์‚ฌ์šฉ ์—†์ด Attention๋งŒ์œผ๋กœ ๋ชจ๋ธ์ด ๊ตฌ์„ฑ๋˜์–ด ์žˆ๋‹ค.

RNN์˜ Long-Term Dependency ํ•ด๊ฒฐ

  • Long-term dependency
    • The phenomenon where, when the gap between the given data and the information to be learned is large, it is hard to connect the contexts of the two.
    • It can be mitigated with models such as the LSTM (Long Short-Term Memory) network.
  • With the attention mechanism, the model can focus on the important vectors no matter how long the sequence is, so long-term dependency can be resolved.
  • In experiments with a Seq2Seq English-German translation model, applying an attention model yielded a modest performance gain.

๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ์ž˜ ํ•™์Šต๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•  ๋•Œ

  • ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ์ œ๋Œ€๋กœ ํ•™์Šต๋˜์ง€ ์•Š์•˜์œผ๋ฉด, ์œ„์™€ ๊ฐ™์ด ์—ฐ๊ด€๋œ ๋ฒกํ„ฐ ๊ฐ„์˜ ์ค‘์š”๋„๊ฐ€ ๋†’๊ฒŒ ๋‚˜์˜ค์ง€ ์•Š์„ ๊ฒƒ์ด๋‹ค. ์ด๋Ÿฐ ํ˜„์ƒ์„ ํ†ตํ•ด ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ์ œ๋Œ€๋กœ ํ•™์Šต๋˜์—ˆ๋Š”์ง€ ํ™•์ธ ๊ฐ€๋Šฅํ•˜๋‹ค.
  • Output์—์„œ Sequence Input ์ค‘์—์„œ ์–ด๋Š ๊ฐ’์ด ์ค‘์š”ํ•˜๊ฒŒ ์‚ฌ์šฉ๋˜์—ˆ๋Š”์ง€ ์•Œ ์ˆ˜ ์žˆ๋‹ค.
    • ์œ„ ์ด๋ฏธ์ง€์—์„œ ๋ณด๋ฉด, ๋‘ ๋‹จ์–ด์˜ ๋œป์ด ๊ฐ™์€ ๋ถ€๋ถ„์ด ํ™œ์„ฑํ™”๊ฐ€ ๋˜์–ด ์žˆ๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.
    • ์˜ˆ๋ฅผ ๋“ค์–ด ๋ถˆ์–ด accord์™€ ์˜์–ด agreement๋Š” ๊ฐ™์€ ๋œป์ด๋ฏ€๋กœ Attention Model์˜ Output์ด ํฐ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.
    • ์ „ํ˜€ ์ƒ๊ด€ ์—†๋Š” ๋ถ€๋ถ„์˜ Attention์ด ๋†’๊ฒŒ ๋‚˜์˜ค๋ฉด ๋ชจ๋ธ์ด ์ž˜๋ชป ํ•™์Šต๋˜์—ˆ๋‹ค๊ณ  ๋ณผ ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋‹ค.
  • ์ด๋ฏธ์ง€ ์บก์…”๋‹ ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ, ๊ฐ ๋‹จ์–ด๋ณ„๋กœ ์ด๋ฏธ์ง€์˜ ์–ด๋–ค ๋ถ€๋ถ„ ๋•Œ๋ฌธ์— ๊ทธ ๋‹จ์–ด๊ฐ€ ์ƒ์„ฑ๋˜์—ˆ๋Š”์ง€ ํ™•์ธ ๊ฐ€๋Šฅํ•˜๋‹ค.

์ฐธ๊ณ ์ž๋ฃŒ

๋…ผ๋ฌธ

๊ธฐํƒ€