学習が進んだモデルでのみnanが発生 - Shinichi0713/Reinforce-Learning-Study GitHub Wiki
「学習が進んだモデルでのみnanが発生し、学習初期のモデルでは同じデータでもnanが発生しない」という現象には、以下のような原因が考えられます。
主な原因
1. 重みの発散・勾配爆発
- 学習が進むにつれてパラメータが極端な値になり、内部の計算(線形変換や活性化関数)が発散してしまい、最終的にnanが発生するケースが多いです。
2. 活性化関数による数値不安定
- 例えばReLUやtanh、sigmoidなどの活性化関数の入力が大きな値になると、計算結果がinfやnanになりやすくなります。
- 特にsoftmaxやlog、expなどは極端な値でnanを生みやすいです。
3. 損失関数の不安定化
- 学習が進むと出力分布が偏り、損失関数(例:クロスエントロピーやMSE)の計算で0除算やlog(0)が発生しやすくなります。
4. 正則化やバッチノルムの不安定化
- バッチノーマライゼーションやレイヤーノーマライゼーションの統計量が不安定になり、nanを生むことがあります。
5. 学習率が大きすぎる
- 特に後半でlossが急激に下がったり、学習率スケジューラで学習率が上がったりすると、重みが一気に発散することがあります。
なぜ「学習が進んだモデル」でnanが起きやすいのか?
- 重みが初期値のときは値が安定しているので、極端な値が出にくい。
- 学習が進むと重みが大きくなったり、偏ったりして数値計算が不安定になるため、nanが出やすくなります。
対策
- 学習率を下げる
→ 発散しにくくなります。 - 勾配クリッピングを導入する
→torch.nn.utils.clip_grad_norm_
などで勾配の大きさを制限。 - 重みの初期化や正則化の見直し
- 活性化関数や損失関数の入力値を監視
→ 中間出力やlossの値をprint/logして異常値を早期発見。 - バッチノルムの統計量を確認する
デバッグのヒント
- どの層・どの時点でnanが発生するかを特定するため、forwardの途中で
torch.isnan(tensor).any()
でチェックを入れると原因特定が早くなります。
まとめ
学習が進んだモデルでnanが発生するのは、重みの発散や数値不安定が主な原因です。
学習率の見直し・勾配クリッピング・中間値の監視などで対策しましょう。