torchの関数について - Shinichi0713/LLM-fundamental-study GitHub Wiki
torch.einsum
torch.einsum
は、PyTorchで**多次元テンソルの「インデックス記法による演算(Einstein summation)」**を行うための関数です。
インデックス記法による演算(Einstein summation、アインシュタインの縮約記法) テンソル(多次元配列)やベクトル・行列の演算を、添字(インデックス)を使って簡潔に表現する記法を意味する。 アインシュタインの縮約記法(Einstein summation convention)は、同じ添字が上下または左右に2回現れたとき、その添字について自動的に総和(sum)を取るというルールです。
これにより、数式を非常にコンパクトに書けるようになります。 アインシュタイン記法C_{ik} = A_{ij} B_{jk}
ここで、添字 ( j ) が両方の項に現れているので、「( j ) について総和を取る」という意味になります。
どんな処理?
- 行列積や内積、転置、和など、複雑なテンソル演算を簡単な文字列で記述できる関数です。
- NumPyの
einsum
と同じ記法です。 - 「どの次元をどのように計算するか」を文字列で指定します。
主な使い方
1. 内積(ドット積)
a = torch.randn(3)
b = torch.randn(3)
result = torch.einsum('i,i->', a, b) # スカラー(内積)
2. 行列積
A = torch.randn(2, 3)
B = torch.randn(3, 4)
result = torch.einsum('ik,kj->ij', A, B) # 通常の行列積
3. バッチごとの内積
A = torch.randn(5, 3, 4)
B = torch.randn(5, 4, 6)
result = torch.einsum('bij,bjk->bik', A, B) # バッチごとの行列積
4. 和の計算
A = torch.randn(2, 3, 4)
result = torch.einsum('ijk->', A) # テンソル全体の和
例:質問のコード
torch.einsum('bnw,bnw->bn', [q, k])
b, n, w
という3次元のテンソル同士の各要素を掛けて、w
方向に和を取る(内積)。
まとめ
torch.einsum
は、複雑なテンソル演算を柔軟に一行で記述できる便利な関数です。- インデックス記法('ij,jk->ik'など)で演算内容を指定します。
- 内積、行列積、バッチ演算、和、転置、アインシュタイン和など、幅広く使われます。
eval()
モード / train()
モード
1. 挙動が変わる層
a. BatchNorm(バッチ正規化)層
- trainモード:
- 入力ミニバッチの平均・分散を使って正規化し、その値でrunning mean/varを更新します。
- evalモード:
- 学習中に蓄積したrunning mean/varを使って正規化します(バッチ内の統計量は使いません)。
- →推論時はevalモードにしないと、出力が不安定になります。
b. Dropout層
- trainモード:
- 指定した確率でニューロンをランダムに無効化(0に)します。
- evalモード:
- Dropoutを行わず、全てのニューロンをそのまま使います。
- →推論時はDropoutがオフになるので、出力が安定します。
2. モデルのパラメータや勾配
- どちらのモードでもパラメータは同じです。
eval()
やtrain()
はforwardの挙動を切り替えるだけで、パラメータや勾配計算(torch.no_grad()
とは別)には直接関係ありません。
3. 使い分け
- 学習時は
model.train()
にします(デフォルトでこの状態です)。 - 推論・評価時は
model.eval()
にします。
まとめ
model.train()
:BatchNormやDropoutが「訓練モード」で動作model.eval()
:BatchNormやDropoutが「推論モード」で動作- それ以外の層(Conv, Linearなど)は挙動は同じ
@torch.no_grad()の機能
@torch.no_grad()
は、その関数内で実行されるすべての処理で自動微分(勾配計算)が無効になることを意味します。
具体的に何が起こるか
- 勾配が記録されなくなる
→ その関数内での計算結果はrequires_grad=True
のテンソルでも履歴が保存されず、バックプロパゲーション(誤差逆伝播)で使われません。 - メモリ効率が良くなる
→ 勾配計算用の中間結果を保存しないので、メモリ消費が減ります。 - 推論やパラメータ更新のない処理に最適
→ 例えば、モデルの推論、重みの手動更新、バッファの更新(例:queueの更新)などに使います。
例
@torch.no_grad()
def test(model, data):
output = model(data)
# ここでは勾配計算が行われない
return output
- 上記のような関数では、計算グラフが作られないため、
loss.backward()
などを呼んでもエラーになります。
まとめ
@torch.no_grad()
が付いた関数では勾配計算が無効化される。- 「推論」や「重みの手動更新」など、学習とは直接関係ない処理で使うのが一般的。