マルチヘッドアテンション - Shinichi0713/LLM-fundamental-study GitHub Wiki
マルチヘッドアテンションとは
マルチヘッドアテンションは、複数のアテンションヘッドを持ち、それぞれが異なるパラメータセットを使用してアテンションを計算します。
これにより、異なるヘッドが入力情報の異なる部分に注目し、複数の視点から情報を集約することができます。
性能
- 表現力の向上
複数のアテンションヘッドを使用することで、モデルはより豊かで多様な表現を学習することができます。
各ヘッドが異なる特徴を捉えるため、最終的な統合表現はより多面的で表現力豊かになります。 - 並列計算の効率性
マルチヘッドアテンションは、各ヘッドの計算を並列して行うことができるため、計算効率が高いです。
これにより、大規模なデータセットや複雑なモデルでも効率的に学習が可能です。 - 局所的およびグローバルな情報の同時処理
異なるヘッドが局所的な情報(例えば、隣接する単語)とグローバルな情報(例えば、文全体の構造)を同時に処理することができます。
これにより、モデルは文脈に応じた柔軟な情報処理が可能になります。 - 情報の分散表現
マルチヘッドアテンションは、情報を複数のサブスペースに分散して表現します。
これにより、モデルは高次元の特徴空間で情報をより効果的に表現し、過学習を防ぐ効果も期待できます。
処理
実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# Linear projections
# ヘッドを分割して、各ヘッドに対して分散表現計算を行う
Q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
# ドット積アテンションを採用しているのは,行列計算テクニックを用いて高速化しやすい
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
# Apply attention to values
context = torch.matmul(attn, V)
# アテンション計算結果を集約
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
# Final linear projection
output = self.out_proj(context)
return output
# Example usage
batch_size = 32
seq_length = 10
embed_dim = 512
num_heads = 8
query = torch.randn(batch_size, seq_length, embed_dim)
key = torch.randn(batch_size, seq_length, embed_dim)
value = torch.randn(batch_size, seq_length, embed_dim)
multi_head_attn = MultiHeadAttention(embed_dim, num_heads)
output = multi_head_attn(query, key, value)
print(output.shape) # Should be (batch_size, seq_length, embed_dim)
AIが賢くなった理由
〇attention is all you need
入力の中の関係性を見るアテンション、出力の関係性を見るアテンションがあり、相互の関係性を見るアテンションがあった
LLMの歴史
- 入力を扱うモジュールがBERT
- その後、OpenAIはデコーダに注目→次々に文章を出すことが出来る→AIと人間がコミュニケーションとれるようになった
- 2022年以降は学校の試験のように質問をして、回答を作らせるのがGPT-3.5→AIの汎化性能が向上した→対話を上手に出来るようにしたのがchat-gpt