DQNのロス関数 - Shinichi0713/RL-study GitHub Wiki
はい、DQN(Deep Q-Network)におけるロス関数(損失関数)の取り方にはいくつかのバリエーションがあります。
基本形から発展形まで、代表的なものを解説します。
1. 標準的なDQNのロス関数(MSE)
最も基本的なDQNのロス関数は、Q学習のベルマン方程式に基づく**平均二乗誤差(MSE)**です。
\text{Loss} = \mathbb{E}_{(s,a,r,s')} \left[ \left( Q(s, a) - y \right)^2 \right]
ここで、
- ( Q(s, a) ):現在のQネットワークによる状態(s)・行動(a)の値
- ( y = r + \gamma \max_{a'} Q_{\text{target}}(s', a') ):ターゲット値
- ( Q_{\text{target}} ):ターゲットネットワーク
- ( \gamma ):割引率
2. Huber Loss(スムーズL1損失)
MSEは外れ値に弱いので、**Huber損失(smooth L1 loss)**を使うことも多いです。
PyTorchではtorch.nn.SmoothL1Loss
で実装できます。
\text{Huber}(x) =
\begin{cases}
\frac{1}{2}x^2 & \text{if } |x| < \delta \\
\delta (|x| - \frac{1}{2}\delta) & \text{otherwise}
\end{cases}
ここで $x = Q(s,a) - y$。
3. Double DQNのロス関数
Double DQNではターゲット値の計算方法が異なりますが、ロス関数自体は上記と同じ形です。
ターゲット値が
y = r + \gamma Q_{\text{target}}(s', \arg\max_{a'} Q(s', a'))
となります。
4. 多段階リターン(n-step DQN)
1ステップだけでなく、nステップ先までの報酬を使うn-step DQNもあります。
y = r_1 + \gamma r_2 + \cdots + \gamma^{n-1} r_n + \gamma^n \max_{a'} Q_{\text{target}}(s_{t+n}, a')
5. 分布型DQN(Distributional DQN)
Q値を確率分布として扱う分布型DQN(C51、QR-DQNなど)では、
クロスエントロピー損失やKLダイバージェンスが使われます。
6. PER(Prioritized Experience Replay)対応
PERを使う場合、サンプルごとに**重要度重み(importance sampling weight)**をかけて損失を計算します。
\text{Loss} = w_i \cdot (Q(s, a) - y)^2
まとめ
- 基本はMSE(平均二乗誤差)
- より安定させたい場合はHuber損失
- Double DQNやn-step DQNなど発展型もターゲット値の計算が異なるだけでロスの形自体は同じ
- 分布型DQNではクロスエントロピーなどを使う