マルチヘッドアテンション - Shinichi0713/LLM-fundamental-study GitHub Wiki

マルチヘッドアテンションとは

マルチヘッドアテンションは、複数のアテンションヘッドを持ち、それぞれが異なるパラメータセットを使用してアテンションを計算します。
これにより、異なるヘッドが入力情報の異なる部分に注目し、複数の視点から情報を集約することができます。

性能

  1. 表現力の向上 複数のアテンションヘッドを使用することで、モデルはより豊かで多様な表現を学習することができます。
    各ヘッドが異なる特徴を捉えるため、最終的な統合表現はより多面的で表現力豊かになります。
  2. 並列計算の効率性 マルチヘッドアテンションは、各ヘッドの計算を並列して行うことができるため、計算効率が高いです。
    これにより、大規模なデータセットや複雑なモデルでも効率的に学習が可能です。
  3. 局所的およびグローバルな情報の同時処理 異なるヘッドが局所的な情報(例えば、隣接する単語)とグローバルな情報(例えば、文全体の構造)を同時に処理することができます。
    これにより、モデルは文脈に応じた柔軟な情報処理が可能になります。
  4. 情報の分散表現 マルチヘッドアテンションは、情報を複数のサブスペースに分散して表現します。
    これにより、モデルは高次元の特徴空間で情報をより効果的に表現し、過学習を防ぐ効果も期待できます。

処理

image

image

image

image

実装

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections
        # ヘッドを分割して、各ヘッドに対して分散表現計算を行う
        Q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        # ドット積アテンションを採用しているのは,行列計算テクニックを用いて高速化しやすい
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)

        # Apply attention to values
        context = torch.matmul(attn, V)
        # アテンション計算結果を集約
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        
        # Final linear projection
        output = self.out_proj(context)
        return output

# Example usage
batch_size = 32
seq_length = 10
embed_dim = 512
num_heads = 8

query = torch.randn(batch_size, seq_length, embed_dim)
key = torch.randn(batch_size, seq_length, embed_dim)
value = torch.randn(batch_size, seq_length, embed_dim)

multi_head_attn = MultiHeadAttention(embed_dim, num_heads)
output = multi_head_attn(query, key, value)

print(output.shape)  # Should be (batch_size, seq_length, embed_dim)

AIが賢くなった理由

〇attention is all you need
入力の中の関係性を見るアテンション、出力の関係性を見るアテンションがあり、相互の関係性を見るアテンションがあった

LLMの歴史

  • 入力を扱うモジュールがBERT
  • その後、OpenAIはデコーダに注目→次々に文章を出すことが出来る→AIと人間がコミュニケーションとれるようになった
  • 2022年以降は学校の試験のように質問をして、回答を作らせるのがGPT-3.5→AIの汎化性能が向上した→対話を上手に出来るようにしたのがchat-gpt
    image