ダブルDQNで強化学習 - Shinichi0713/RL-study GitHub Wiki

概要

ダブルDQNは、行動の選択と行動がどれくらい良いかの評価を別々のネットワークでも行うというもの。 従来の深層強化学習では、行動の選択と評価を同じネットワークで行っていたため行動の価値を実際よりも高く見積もってしまうという欠点があった。 ダブルDQNでは、この欠点を解消するため、2つのネットワークを使うように改良されている。

従来技術の課題

強化学習の分野において、近年大きな成果をあげた手法の一つに、「深い行動価値関数ネットワーク」、略して「深層価値関数ネットワーク」という手法があります。この手法は、人間の脳の神経回路網を模倣した「深層学習」と呼ばれる技術を使って、複雑な環境における学習を可能にしました。簡単に言うと、膨大な数の行動とその結果得られる報酬の関係性を、深層学習によって近似的に表現することで、エージェントは最適な行動を効率よく学習できるようになります。しかし、この画期的な手法にも弱点がありました。それは、行動の価値を本来よりも高く見積もってしまう傾向があることです。例えるなら、宝くじの当選確率を実際よりも高く見積もってしまうようなものです。この過大評価は、学習の効率を低下させる要因となります。

なぜ、高く見積もってしまう?

深層強化学習(Deep Reinforcement Learning, DRL)において「行動の価値(Q値)を本来よりも高く見積もってしまう傾向(overestimation bias)」が生じる主な原因は、 最大値の推定(max operator)と関数近似(ニューラルネットワーク等)による推定誤差が組み合わさることです。

詳細な説明

  1. 最大値演算子(max operator)によるバイアス
    Q学習(Q-learning)などでは、次状態での最大Q値を使って現在のQ値を更新します。
    しかし、Q値はニューラルネットワークなどの関数近似によって推定されており、**推定誤差(ノイズ)**が存在します。
    複数の行動のQ値推定値の中から最大値を選ぶと、たまたま推定値が高くなった(誤差の大きい)行動が選ばれやすくなります。
    これにより、本来の期待値以上にQ値が高く見積もられる現象が起こります。
  2. 関数近似の誤差
    ニューラルネットワークなどの関数近似器は、真のQ値ではなく推定値を出力します。
    推定値には誤差が含まれており、これが最大値選択時にバイアスを生みます。
  3. 数学的には…
    期待値の最大値(E[max(Q)])は、最大値の期待値(max(E[Q]))よりも大きくなります( Jensenの不等式)。
    これは「最大値の推定バイアス」と呼ばれます。

image

仕組み

ダブルDQNは、従来のディーキューエヌが抱えていた学習の過大評価、つまり実際よりも良いものと判断してしまう問題を解決するために開発されました。この過大評価は、同じ一つのネットワークで行動の選択と、その行動の良し悪しの評価を同時に行っていたことが原因でした。美味しいかどうかを自分で決めて、自分で作った料理を自分で食べるようなもので、どうしても甘くなってしまうのです。
この問題に対処するため、ダブルDQNは二つのネットワークを使うという工夫を凝らしました。ちょうど、料理を作る人と、その料理を味わって評価する人を分けるようなものです。一つ目のネットワークは行動の選択を担当します。様々な行動の中から、どれが一番良い結果に導くと考えられるかを選び出すのです。まるで、数ある食材の中から、今日の献立を決める料理人のようです。そして、二つ目のネットワークは、選ばれた行動の価値を評価します。まるで、料理人が作った料理を味わう、客観的な味覚を持つ審査員のようです。
行動の選択と価値の評価を別々のネットワークで行うことで、過大評価のリスクを減らすことができます。これは、自分の作った料理を自分で評価するのではなく、他の人に評価してもらうことで、公平な評価が得られるのと同じです。例えば、ゲームで敵を倒す行動が良いと判断した場合、その判断自体に間違いがなくても、報酬の予測値が実際よりも高くなってしまう可能性がありました。ダブルディーキューエヌでは、行動の選択は一つ目のネットワークで行いますが、その行動による報酬の予測は二つ目のネットワークで行います。これにより、過大な期待を抱くことなく、より正確な学習を進めることができるのです。このように、ダブルディーキューエヌは、二つのネットワークを巧みに使い分けることで、より安定した、信頼性の高い学習を実現しています。

この手法の良い点

  1. DQNの評価を過大化させる問題を抑制できる。

このメリットの効果

  1. ゲームプレイにおいては以前よりも高い得点を得られるようになった事例も報告されています。さらに、ダブルDQNはDQNを少し変更するだけで実装できるため、導入の容易さも大きな利点です。
  2. 既存のDQNを扱うコードに少し手を加えるだけで、簡単にダブルDQNへ移行できます。まさにDQNが抱えていた欠点を克服し、長所を伸ばした手法と言えるでしょう。ダブルDQNは、深層強化学習における新たな一歩と言える革新的な技術です。

ネットワークの役割

  1. メインネットワーク(Main Network) 役割: 現在の状態から最適な行動(アクション)を選択するためのQ値(行動価値)を予測します。 学習時の挙動: リプレイバッファからサンプルした遷移(状態・行動・報酬・次状態)を使って、損失関数を計算し、**重みを更新(学習)**します。 行動選択: 通常、$ϵ$-greedy などで行動を選ぶ際にもこのネットワークを使います。
  2. ターゲットネットワーク(Target Network) 役割: Q学習のターゲット(正解値)となるQ値を安定して推定するために使います。 学習時の挙動: ターゲットネットワークは基本的に重みの更新はしません。一定のステップごと、またはソフト更新(Polyak平均)で、メインネットワークの重みをコピーします。 ターゲット値の計算: DDQNでは、「次状態で最大のQ値」を計算する際に、 どの行動が最大か:メインネットワークで決める その行動のQ値はいくらか:ターゲットネットワークで出す という分担をします。

DDQNでのターゲット値の計算
通常のDQNのターゲットQ値の計算

y=r+γ max_{a'} Q_{target}(s′, a)

DDQNでのターゲットQ値の計算

\begin{align}
a^{*} = argmax_{a'} Q_{main}(s', a') \\
y = r + \gamma Q_{target}(s', a^{*}) \\
\end{align}

ターゲットネットワークが「重みの更新をしなくても良い理由」

  1. 目的は「安定したターゲット値」の提供 DDQNやDQNでは、Q値のターゲット(教師信号)として「ターゲットネットワーク」を使います。 もしターゲットネットワークもメインネットワークと同じタイミング・頻度で重みを更新してしまうと、 「ターゲット値」も「予測値」も同時に大きく変化してしまい、 学習が不安定になりやすいです(発散や収束しないなど)。

  2. ターゲット値の「自己参照的な更新」を防ぐため Q学習では

y=r+γ max_{a'} Q_{target}(s′, a)

同じネットワークで両方のQ値を算出すると、誤った学習方向に引っ張られる。

  1. 「遅れた重み」で安定性を確保 ターゲットネットワークは「メインネットワークの重みを一定間隔でコピー」するだけです。 これにより、ターゲット値はしばらくの間「固定」されます。 この「遅れ」が安定した学習をもたらします。

pythonでの実装

# 行動選択
action = main_network(state).argmax().item()

# ターゲット値計算
next_action = main_network(next_state).argmax().item()
target_q = target_network(next_state)[next_action].item()
target = reward + gamma * target_q

DDQNとDQNのネットワークは共通?

はい、DQNとDDQNで使うニューラルネットワークの構造(アーキテクチャ)は同じものでも大丈夫です


詳細

  • **DQN(Deep Q-Network)DDQN(Double DQN)**は、どちらも「状態を入力し、各行動のQ値を出力する」ニューラルネットワークを用います。
  • ネットワークの構造(層の数、ユニット数、活性化関数など)や重みの初期値は、DQNとDDQNで共通で構いません。
  • 違いは「ターゲット値の計算方法」や「学習時のアルゴリズム」にあります。
    • DQN:ターゲット値に max Q_target(s', a') を使う
    • DDQN:ターゲット値に Q_target(s', argmax Q_online(s', a')) を使う

まとめ

  • ネットワーク構造は共通でOK
  • アルゴリズム(ターゲット値の計算方法)が異なるだけ

補足:
実装上も、DQN/Double DQNのどちらでも「オンラインネットワーク」と「ターゲットネットワーク」は同じアーキテクチャで作るのが一般的です。
(PyTorchやTensorFlowのコードでも同じクラスを使い回します)