DQNのロス関数 - Shinichi0713/RL-study GitHub Wiki

はい、DQN(Deep Q-Network)におけるロス関数(損失関数)の取り方にはいくつかのバリエーションがあります。
基本形から発展形まで、代表的なものを解説します。

1. 標準的なDQNのロス関数(MSE)

最も基本的なDQNのロス関数は、Q学習のベルマン方程式に基づく**平均二乗誤差(MSE)**です。

\text{Loss} = \mathbb{E}_{(s,a,r,s')} \left[ \left( Q(s, a) - y \right)^2 \right]

ここで、

  • ( Q(s, a) ):現在のQネットワークによる状態(s)・行動(a)の値
  • ( y = r + \gamma \max_{a'} Q_{\text{target}}(s', a') ):ターゲット値
    • ( Q_{\text{target}} ):ターゲットネットワーク
    • ( \gamma ):割引率

2. Huber Loss(スムーズL1損失)

MSEは外れ値に弱いので、**Huber損失(smooth L1 loss)**を使うことも多いです。
PyTorchではtorch.nn.SmoothL1Lossで実装できます。

\text{Huber}(x) = 
\begin{cases}
\frac{1}{2}x^2 & \text{if } |x| < \delta \\
\delta (|x| - \frac{1}{2}\delta) & \text{otherwise}
\end{cases}

ここで $x = Q(s,a) - y$。


3. Double DQNのロス関数

Double DQNではターゲット値の計算方法が異なりますが、ロス関数自体は上記と同じ形です。
ターゲット値が

y = r + \gamma Q_{\text{target}}(s', \arg\max_{a'} Q(s', a'))

となります。


4. 多段階リターン(n-step DQN)

1ステップだけでなく、nステップ先までの報酬を使うn-step DQNもあります。

y = r_1 + \gamma r_2 + \cdots + \gamma^{n-1} r_n + \gamma^n \max_{a'} Q_{\text{target}}(s_{t+n}, a')

5. 分布型DQN(Distributional DQN)

Q値を確率分布として扱う分布型DQN(C51、QR-DQNなど)では、
クロスエントロピー損失やKLダイバージェンスが使われます。


6. PER(Prioritized Experience Replay)対応

PERを使う場合、サンプルごとに**重要度重み(importance sampling weight)**をかけて損失を計算します。

\text{Loss} = w_i \cdot (Q(s, a) - y)^2

まとめ

  • 基本はMSE(平均二乗誤差)
  • より安定させたい場合はHuber損失
  • Double DQNやn-step DQNなど発展型もターゲット値の計算が異なるだけでロスの形自体は同じ
  • 分布型DQNではクロスエントロピーなどを使う