torchの関数について - Shinichi0713/LLM-fundamental-study GitHub Wiki

torch.einsum

torch.einsumは、PyTorchで**多次元テンソルの「インデックス記法による演算(Einstein summation)」**を行うための関数です。

インデックス記法による演算(Einstein summation、アインシュタインの縮約記法) テンソル(多次元配列)やベクトル・行列の演算を、添字(インデックス)を使って簡潔に表現する記法を意味する。 アインシュタインの縮約記法(Einstein summation convention)は、同じ添字が上下または左右に2回現れたとき、その添字について自動的に総和(sum)を取るというルールです。
これにより、数式を非常にコンパクトに書けるようになります。 アインシュタイン記法

C_{ik} = A_{ij} B_{jk} 

ここで、添字 ( j ) が両方の項に現れているので、「( j ) について総和を取る」という意味になります。

どんな処理?

  • 行列積や内積、転置、和など、複雑なテンソル演算を簡単な文字列で記述できる関数です。
  • NumPyのeinsumと同じ記法です。
  • 「どの次元をどのように計算するか」を文字列で指定します。

主な使い方

1. 内積(ドット積)

a = torch.randn(3)
b = torch.randn(3)
result = torch.einsum('i,i->', a, b)  # スカラー(内積)

2. 行列積

A = torch.randn(2, 3)
B = torch.randn(3, 4)
result = torch.einsum('ik,kj->ij', A, B)  # 通常の行列積

3. バッチごとの内積

A = torch.randn(5, 3, 4)
B = torch.randn(5, 4, 6)
result = torch.einsum('bij,bjk->bik', A, B)  # バッチごとの行列積

4. 和の計算

A = torch.randn(2, 3, 4)
result = torch.einsum('ijk->', A)  # テンソル全体の和

例:質問のコード

torch.einsum('bnw,bnw->bn', [q, k])
  • b, n, w という3次元のテンソル同士の各要素を掛けて、w方向に和を取る(内積)。

まとめ

  • torch.einsumは、複雑なテンソル演算を柔軟に一行で記述できる便利な関数です。
  • インデックス記法('ij,jk->ik'など)で演算内容を指定します。
  • 内積、行列積、バッチ演算、和、転置、アインシュタイン和など、幅広く使われます。

eval()モード / train()モード

1. 挙動が変わる層

a. BatchNorm(バッチ正規化)層

  • trainモード:
    • 入力ミニバッチの平均・分散を使って正規化し、その値でrunning mean/varを更新します。
  • evalモード:
    • 学習中に蓄積したrunning mean/varを使って正規化します(バッチ内の統計量は使いません)。
  • 推論時はevalモードにしないと、出力が不安定になります。

b. Dropout層

  • trainモード:
    • 指定した確率でニューロンをランダムに無効化(0に)します。
  • evalモード:
    • Dropoutを行わず、全てのニューロンをそのまま使います。
  • 推論時はDropoutがオフになるので、出力が安定します。

2. モデルのパラメータや勾配

  • どちらのモードでもパラメータは同じです。
  • eval()train()forwardの挙動を切り替えるだけで、パラメータや勾配計算(torch.no_grad()とは別)には直接関係ありません。

3. 使い分け

  • 学習時model.train()にします(デフォルトでこの状態です)。
  • 推論・評価時model.eval()にします。

まとめ

  • model.train():BatchNormやDropoutが「訓練モード」で動作
  • model.eval():BatchNormやDropoutが「推論モード」で動作
  • それ以外の層(Conv, Linearなど)は挙動は同じ

@torch.no_grad()の機能

@torch.no_grad()は、その関数内で実行されるすべての処理で自動微分(勾配計算)が無効になることを意味します。

具体的に何が起こるか

  • 勾配が記録されなくなる
    → その関数内での計算結果はrequires_grad=Trueのテンソルでも履歴が保存されず、バックプロパゲーション(誤差逆伝播)で使われません。
  • メモリ効率が良くなる
    → 勾配計算用の中間結果を保存しないので、メモリ消費が減ります。
  • 推論やパラメータ更新のない処理に最適
    → 例えば、モデルの推論、重みの手動更新、バッファの更新(例:queueの更新)などに使います。

@torch.no_grad()
def test(model, data):
    output = model(data)
    # ここでは勾配計算が行われない
    return output
  • 上記のような関数では、計算グラフが作られないため、loss.backward()などを呼んでもエラーになります。

まとめ

  • @torch.no_grad()が付いた関数では勾配計算が無効化される
  • 「推論」や「重みの手動更新」など、学習とは直接関係ない処理で使うのが一般的。