Attention Network - BD-SEARCH/MLtutorial GitHub Wiki
- Natural Language Inference, Sentence representation and Attention Mechanism ๋ ผ๋ฌธ์ ์์ฝ ๋ฐ ์ ๋ฆฌํ๊ณ ์ถ๊ฐ์ ์ธ ๋ด์ฉ์ ๋ง๋ถ์.
Attention Model
- ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ๋ฒกํฐ Sequence ์ค์์ ๊ฐ์ฅ ์ค์ํ ๋ฒกํฐ์ ์ง์คํ๋๋ก ํ๋ ๋ชจ๋ธ
- State๋ฅผ ๊ณ ๋ คํ์ฌ ๊ฐ์ฅ ์ค์๋๊ฐ ๋์ ๋ฒกํฐ๋ฅผ ์ค์ฌ์ผ๋ก ํ๋์ ๋ฒกํฐ๋ก ์ ๋ฆฌํ๋ ๋ชจ๋ธ.
- Input
- y_1, y_2, ..., y_n: ์ ๋ ฅ ๋ฒกํฐ. 1์ฐจ์ ๋ฐ์ดํฐ์ธ Sequence ๋ฟ๋ง ์๋๋ผ 2์ฐจ์ ๋ฐ์ดํฐ์ธ ์ด๋ฏธ์ง ๋ฑ๋ ๋ฐ์ ์ ์๋ค.
- c: Context. ํ์ฌ ์ํ(๋ฌธ๋งฅ ๋ฑ)์ ๋ํ๋ด๋ ๋ฒกํฐ.
- Output
- z: context์ sequence๋ฅผ ๊ณ ๋ คํ์ฌ, y ๋ฒกํฐ ์ค ๊ฐ์ฅ ์ค์ํ ๋ฒกํฐ๋ฅผ ์์ฃผ๋ก summary๋ ๊ฐ.
- Attention Model์ Output์ ์ค์ํ ๋ฒกํฐ๋ฅผ ์์ฃผ๋ก summary๋๊ธฐ ๋๋ฌธ์, Sequence์ ์ค์ํ ๋ถ๋ถ์ ์ง์คํ๋ค๊ณ ๋ณผ ์ ์๋ค.
Attention Model์ ๊ตฌ์กฐ์ ๋์ ๋ฐฉ๋ฒ
Attention Model์ ๊ฐ๋ ์ ์ผ๋ก ์๋์ ๊ฐ์ด ๋์ํ๋ค.
- Input์ผ๋ก ๋ค์ด์จ ๋ฒกํฐ๋ค์ ์ค์๋/์ ์ฌ๋๋ฅผ, ํ์ฌ state๋ฅผ ๊ณ ๋ คํ์ฌ ๊ตฌํ๋ค.
- ๊ฐ๊ฐ์ ์ค์๋๋ฅผ, ์ด ํฉ์ด 1์ด ๋๋ ์๋๊ฐ์ผ๋ก ๋ฐ๊พผ๋ค.
- ์๋๊ฐ ์ค์๋๋ฅผ ๊ฐ์ค์น๋ก ๋ณด๊ณ , Sequence์ ์๋ ๋ฒกํฐ๋ค์ ๊ฐ์ค์นํฉํ๋ค.
Attention Model์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ์ฌ๋ฌ ๊ฐ์ง๊ฐ ์์ผ๋, ๊ทธ ์ค์์ tanh ํจ์์ ๋ด์ ๊ณฑ์ ์ด์ฉํ Attention์ ์๋์ ๊ฐ์ด ๊ตฌํํ ์ ์๋ค.
- State C์ W_c ํ๋ ฌ๋ฅผ ๋ด์ ํ ๊ฐ๊ณผ, ๊ฐ๊ฐ์ Sequence์ ๋ฒกํฐ y_i์ W_y ํ๋ ฌ๋ฅผ ๋ด์ ํ ๊ฐ์ ๋ํ๋ค. ๊ทธ๋ฆฌ๊ณ ์ด ๊ฐ์ tanh ํจ์์ ํต๊ณผํจ ๊ฐ์ m_i์ด๋ผ๊ณ ํ๋ค.
- m_i ๊ฐ์ Softmax ํจ์์ ํต๊ณผ์์ผ ํ๋ฅ ์ ๊ตฌํ ๊ฐ์ s_i๋ผ๊ณ ํ๋ค.
- s_i ๊ฐ๊ณผ y_i ๊ฐ์ ๋ด์ ํ๋ค. ์ด ๊ฐ์ ์ ๋ถ ํฉ์น ๊ฒ์ด ์ถ๋ ฅ๊ฐ์ด ๋๋ค.
์ฃผ์์ฌํญ
- ์ฌ๊ธฐ์ W_c, W_y ํ๋ ฌ์ ํ์ต์ ํตํด ๊ฐ์ ์ ํ๋ค.
- 1์์ ์ฌ์ฉํ ๋ฐฉ๋ฒ์ y_i ๊ฐ๊ณผ C ๊ฐ์ ํ๋๋ก ์๋ ์ด์, ์ด๋ ํ ๋ฐฉ๋ฒ์ ์ฌ์ฉํด๋ ๋ฌด๋ฐฉํ๋ค. (์: y_i์ C ๊ฐ์ ๋ด์ )
์ด์ RNN layer์ ์ถ๋ ฅ๊ฐ์ด ๋ชจ๋ ๋์์ผ Attention Model์ ์ถ๋ ฅ์ ์ ํ ์ ์๋ค. (์ด์ layer์ ๋ชจ๋ Output์ Input์ผ๋ก ๋ฐ๊ธฐ ๋๋ฌธ)
Encoder/Decoder์ Attention Machanism
- ๊ฐ์ ์
๋ ฅ๋ฐ์ ๋ฒกํฐ๋ก ๋ง๋๋ encoder(์ธ์ฝ๋), ์ธ์ฝ๋๊ฐ ์ถ๋ ฅํ ๋ฒกํฐ๋ฅผ ๋ฐํ์ผ๋ก ์ํ๋ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋ decoder(๋์ฝ๋)๊ฐ ์๋ค๊ณ ํ์.
- ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ: ์ธ์ฝ๋์ ๋ชจ๋ ์ถ๋ ฅ ๋ฒกํฐ๋ฅผ ๊ณจ๊ณ ๋ฃจ ๋ณด๊ณ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋ค.
- Attention Mechanism: ์ธ์ฝ๋์ ์ถ๋ ฅ ๋ฒกํฐ ์ค ์ค์ํ ๋ฒกํฐ์ ์ง์คํ๋ค.
- ์) "Artificial Intelligence"๋ฅผ "์ธ๊ณต ์ง๋ฅ"์ผ๋ก ๋ฒ์ญํ๋ค๊ณ ๊ฐ์ ํ๋ค. ์ด ๋ ๋ชจ๋ธ์ด '์ง๋ฅ'์ ์์ธกํ ๋ 'Intelligence'์ ์ฃผ๋ชฉํ๊ฒ ๋๋ค.
- Decoding ๊ณผ์ ์์ ๋ถํ์ํ ๋ฒกํฐ๋ฅผ ๋ณด์ง ์๊ธฐ ๋๋ฌธ์ ์ฑ๋ฅ์ด ํฅ์๋๋ค.
- RNN Encoder/Decoder๋ฅผ ์ฌ์ฉํ ๋ Attention Model์ Input
- y_1, y_2, ..., y_n: ๋ณดํต ์ด์ RNN layer๊ฐ ์ถ๋ ฅํ ๊ฐ์ Sequence๋ก ์ฌ์ฉ.
- c: Attention Model์ Output์ ์ฌ์ฉํ๋ RNN ๋ชจ๋ธ์ ๋ฐ๋ก ์ง์ State๋ฅผ ์ฌ์ฉํ ์ ์๋ค. ์๋ฅผ ๋ค์ด RNN ๋ชจ๋ธ์ ์ธ ๋ฒ์งธ Output์ ๊ณ์ฐํ๋ค๋ฉด, ๋ ๋ฒ์งธ Output์ hidden state ๊ฐ์ ์ฌ์ฉํ ์ ์๋ค.
Attention Model์ Code Example
์ด์ layer์ ์ถ๋ ฅ ๊ฐ์ด 1์ฐจ์์ธ ๊ฒฝ์ฐ
# ์ด์ layer์ ์ ์ฒด ๊ฐ๋ค์ Input์ผ๋ก ๋ฐ๋๋ค
inputs = Input(shape=(input_dims,))
# ๊ฐ๊ฐ์ ๊ฐ์ ๋ํ ์ค์๋๋ฅผ ๊ตฌํ๋ค.
attention_probs = Dense(input_dims, activation='softmax', name='attention_probs')(inputs)
# ๊ฐ๊ฐ์ ๊ฐ Matrix์ ์ค์๋ Matrix๋ฅผ ํ๋ ฌ๊ณฑํ๋ค. ์ฆ, ์
๋ ฅ ๊ฐ๋ค์ ์ค์๋์ ๋ฐ๋ผ ๊ฐ์ค์นํฉํ๋ค.
attention_mul = merge([inputs, attention_probs], output_shape=input_dims, name='attention_mul', mode='mul')
-
philipperemy/keras-attention-mechanism์์ ์์ค๋ฅผ ๊ฐ์ ธ์ ์ฃผ์์ ๋ฌ์.
-
์ํ์ ์ธ ์๋ฏธ์ ํ๋ ฌ(Matrix)๋ฅผ Dense Layer๋ก ๋์ฒดํ์ฌ ๊ตฌํ.
์ด์ layer์ ์ถ๋ ฅ์ด 2์ฐจ์์ธ ๊ฒฝ์ฐ
# inputs.shape = (batch_size, time_steps, input_dim)
input_dim = int(inputs.shape[2])
# Input Matrix๋ฅผ Transposeํ๋ค.
a = Permute((2, 1))(inputs)
a = Reshape((input_dim, TIME_STEPS))(a)
# ๊ฐ๊ฐ์ ๊ฐ์ ๋ฐ๋ผ ์ค์๋๋ฅผ ๊ตฌํ๋ค. (LSTM์ cell ๋ณ)
# Matrix๋ก ๋ณด๋ฉด, ์๋ layer์ ์ถ๋ ฅ์ ํ: sequence index, ์ด: input_dim
a = Dense(TIME_STEPS, activation='softmax')(a)
# ๋ค์ Matrix๋ฅผ Transposeํด์ Input Matrix์ ์ฐจ์์ ์๋ฏธ๊ฐ ๊ฐ๋๋ก ์์
a_probs = Permute((2, 1), name='attention_vec')(a)
# a_probs - ํ: input_dim, ์ด: sequence index
# Input ๊ฐ๋ค์ ์ค์๋์ ๋ฐ๋ผ ๊ฐ์ค์นํฉํ๋ค.
output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul')
-
philipperemy/keras-attention-mechanism์์ ์์ค๋ฅผ ๊ฐ์ ธ์ ์ฃผ์์ ๋ฌ์.
-
Input Matrix๋ฅผ Transposeํ๋ ์ด์
- "๊ฐ๊ฐ์ Sequence์ ๋ฒกํฐ y_i์ W_y Matrix๋ฅผ ๋ด์ ํ ๊ฐ์ ๋ํ๋ค."๋ฅผ Dense layer๋ก ๊ตฌํํ๊ธฐ ์ํด.
- ๋ Matrix์ ์ฐจ์์ ์๋ฏธ๋ฅผ ๋์ผํ๊ฒ ๋ง์ถ๊ธฐ ์ํด
- y_i๋ฅผ ์ ์ฒด๋ก ๋ชจ์ Y Matrix์ M_y Matrix๋ฅผ ๋ด์ ํ๋ค๊ณ ์๊ฐํ๋ฉด ๋๋ค.
- Dense layer ํ๋๋ฅผ, ํ์ตํด์ผ ํ Matrix๋ก ์๊ฐํ๋ฉด ์ดํดํ ์ ์๋ค.
Attention Mechanism์ ํ์ฉ
- Output์ด Sequence ํํ๋ผ๋ฉด, Input์ด ์ด๋ค ํํ๋ ๊ฐ์, Input์ ํน์ ๋ถ๋ถ์ ๊ฐ์กฐํด์ ๋ณด๋ ํํ๋ก Attention Mechanism์ ์ฌ์ฉํ ์ ์๋ค.
- Long-Term Dependency๋ฅผ ํด๊ฒฐํ ๋, ํน์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ต ์ํฉ์ ์ ๊ฒํ ๋ ํ์ฉ ๊ฐ๋ฅํ๋ค.
- Transformer, BERT ๋ฑ์์ CNN/RNN ์ฌ์ฉ ์์ด Attention๋ง์ผ๋ก ๋ชจ๋ธ์ด ๊ตฌ์ฑ๋์ด ์๋ค.
RNN์ Long-Term Dependency ํด๊ฒฐ
- Long-Term Dependency(์ฅ๊ธฐ ์์กด์ฑ)
- ์ ๊ณต๋ ๋ฐ์ดํฐ์ ๋ฐฐ์์ผ ํ ์ ๋ณด์ ์ ๋ ฅ ์ฐจ์ด(Gap)๊ฐ ํฐ ๊ฒฝ์ฐ ๋ ์ ๋ณด์ ๋ฌธ๋งฅ์ ์ฐ๊ฒฐํ๊ธฐ ์ด๋ ค์ด ํ์.
- LSTM(Long Short Term Memory Network) ๋ฑ์ ํ์ฉํ์ฌ ํด๊ฒฐํ ์ ์๋ค.
- Attention Mechanism์ ์ด์ฉํ๋ฉด Sequence๊ฐ ๊ธธ๋๋ผ๋ ๊ทธ ์ค์์ ์ค์ํ ๋ฒกํฐ์ ์ง์คํ ์ ์์ผ๋ฏ๋ก, Long-Term Dependency๋ฅผ ํด๊ฒฐํ ์ ์๋ค.
- Seq2Seq๋ฅผ ์ด์ฉํ ์์ด-๋ ์ผ์ด๊ฐ ๋ฒ์ญ ๋ชจ๋ธ์ ๊ฐ์ง๊ณ ์คํ์ ํ ๊ฒฐ๊ณผ, Attention Model์ ์ ์ฉํ๋ฉด ์ฑ๋ฅ์ด ์ํญ ์์นํ ๊ฒฐ๊ณผ๋ ์๋ค.
๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ์ ํ์ต๋์๋์ง ํ์ธํ ๋
- ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ์ ๋๋ก ํ์ต๋์ง ์์์ผ๋ฉด, ์์ ๊ฐ์ด ์ฐ๊ด๋ ๋ฒกํฐ ๊ฐ์ ์ค์๋๊ฐ ๋๊ฒ ๋์ค์ง ์์ ๊ฒ์ด๋ค. ์ด๋ฐ ํ์์ ํตํด ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ์ ๋๋ก ํ์ต๋์๋์ง ํ์ธ ๊ฐ๋ฅํ๋ค.
- Output์์ Sequence Input ์ค์์ ์ด๋ ๊ฐ์ด ์ค์ํ๊ฒ ์ฌ์ฉ๋์๋์ง ์ ์ ์๋ค.
- ์ ์ด๋ฏธ์ง์์ ๋ณด๋ฉด, ๋ ๋จ์ด์ ๋ป์ด ๊ฐ์ ๋ถ๋ถ์ด ํ์ฑํ๊ฐ ๋์ด ์๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
- ์๋ฅผ ๋ค์ด ๋ถ์ด accord์ ์์ด agreement๋ ๊ฐ์ ๋ป์ด๋ฏ๋ก Attention Model์ Output์ด ํฐ ๊ฒ์ ํ์ธํ ์ ์๋ค.
- ์ ํ ์๊ด ์๋ ๋ถ๋ถ์ Attention์ด ๋๊ฒ ๋์ค๋ฉด ๋ชจ๋ธ์ด ์๋ชป ํ์ต๋์๋ค๊ณ ๋ณผ ์ ์์ ๊ฒ์ด๋ค.
- ์ด๋ฏธ์ง ์บก์ ๋ ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ๊ฐ ๋จ์ด๋ณ๋ก ์ด๋ฏธ์ง์ ์ด๋ค ๋ถ๋ถ ๋๋ฌธ์ ๊ทธ ๋จ์ด๊ฐ ์์ฑ๋์๋์ง ํ์ธ ๊ฐ๋ฅํ๋ค.
์ฐธ๊ณ ์๋ฃ
๋ ผ๋ฌธ
- Natural Language Inference, Sentence representationand Attention Mechanism
- Attention Is All You Need