DeepMindのDQNアルゴリズムを再現してみた

こんにちは。データムスタジオの林です。プログラミングで一番嬉しい瞬間は、書いたコードがうまく動いた時だと思います。

さて、世界最強の囲碁棋士も敵わない、Google DeepMindが開発したAlphaGoは、ご存知のこととは思われますが、AlphaGoが世界的な成功を収める前に、同社が初めて世間の注目を集めた論文、皆さんご存知でしょうか?

それは2015年にNatureに掲載された「Human-level control through deep reinforcement learning」です。同論文と続編では、deep Q-network (DQN)という革新的な強化学習手法により、人工知能エージェントをAtari 2600ゲーム機の殆どのソフトで人間レベルのパフォーマンスを得られるように学習させたらしいです。

今回はこの論文が示したDQNアルゴリズムをPythonで再現し、そしてOpenAIによるgym環境のエージェントを自動的に学習させたいと思います。

では、早速始めましょう!

0. 強化学習の中心概念

強化学習の中心概念は「エージェント」「環境」「状態」「行動」「報酬」の五つとなります。

教師あり学習と違い、強化学習の学び方は与えられた訓練データを通じるではなく、「エージェント」が「環境」とインタラクトする中で積み重ねた経験より最適方策を練っていくのです。つもり、

ある環境におけるエージェントが状態\(s_{1}\)に対して行動\(a\)をし、その行動の結果として報酬\(r\)を得るのと同時に、環境の状態も行動\(a\)により状態\(s_{2}\)に変化していく

というステップの繰り返しです。こういう過去ステップを基に、長期的な報酬を最大化できる方策(各状態における最適行動)を学習するのが強化学習の目標です。

1. DQNの主要素

DQNアルゴリズムはQ-learningという強化学習手法の変種です。従来Q-learningとの違いが二つあります:

1.1 Q-Network

Q-learningアルゴリズムでは下記のような、Q-tableという行列で方策を表します:

ここで各行、列がそれぞれ状態、行動に対応しており、値が各状態において各行動を取った場合の期待報酬です。

しかし、Q-tableでは連続な状態空間或いは行動空間に対応出来ないため、DQNはディープランニングを導入し、有限個の報酬ではなく、ニューラルネットワークに\(Q:S\times A\to\mathbb{R}\)関数を近似する役割を与えています。あらゆる\((s, a)\)順序対に対して報酬を予測するこのネットワークはQ-networkと言います。

1.2 経験再生(experience replay)

更に、ゲームの流れが連続なので、各状態が前の状態と強く相関していることが多いです。ゆえに、Q-networkを学習時、確率的勾配降下法を類似しているサンプルにフィットさせると、局所的最小値に落ちてしまう恐れがあるらしいです。なので、ゲーム中続々発生したステップを順番に学習させるではなく、まずメモリに保存し、そしてランダムに抽出したサンプルにフィットさせるのです。この仕組みは経験再生と言います。

2. gym環境紹介

強化学習アルゴリズムをテストするのに環境が必要です。今回使うgymはOpenAlによる、強化学習アルゴリズム開発、テスト用のゲームのような環境です。gymをインストールするには

を実行します。

インストールが済んだらエージェントを動かせてみましょう。例えば、ランダムエージェントにCartPoleゲームを遊ばせるコードは:

CartPole環境の学習目標は、車を左右移動して上に棒を立てることです。上記ランダムエージェントコードを実行すると、このように動きます:

すぐ倒れてしまいました。さすがにランダム行動ではクリア出来るはずがないです。

ということで、強化学習でチャレンジしましょう!

3. DQNを実装

論文に記述された擬似コードは下記の通りです:

この記事では大体擬似コードのままで再現しますが、DeepMindの使った畳み込みニューラルネットワークの代わりに、普通のフィードフォワードニューラルネットワークを採用しています。その理由は二つあります:

  • gym環境において、状態はAtariゲームのようなピクセルではなく、実数の順序対で表されています。
    なので、画像認識の分野で優れた畳み込みニューラルネットワークを使う必要がありません。
  • 学習時間短縮のため。

3.1 Q-Network(ディープランニング)

まずはQ-networkを作成する関数を定義します。ここで_build_network()は初期ニューラルネットワークを作成します。そして_clone_network()の役割は一定周期でフィットしたQ-networkウェートで目標Q-network(アルゴリズムにおける\(\hat{Q}\))のウェートを上書きます。

3.2 最適行動

次はエージェントの行動方策を決める関数です。_get_optimal_action()は単純にQ-networkによる予測報酬が最大の行動を返します。_get_action()は\(\epsilon\)-greedy方策を実装し、ステップが経つにつれ減衰する確率\(\epsilon\)で行動を探索してから、徐々に最適行動に寄り付きます。

3.3 メモリ

こちらはエージェントのメモリに当たります。_remember()関数でエージェントに経験したステップを覚えさせることが出来ます。_init_memory()は学習開始時繰り返して_remember()を呼び出し、ランダムに初期メモリを作成します。

3.4 経験再生

いよいよ最も面白い部分に入りました。ここで_get_samples()はエージェントのメモリからサンプルを抽出し、サンプル別に即時報酬に最大遅延報酬を足します。それから、_get_samples()に基づいて、_replay()は抽出したサンプルと報酬にQ-networkをフィットさせます。

3.5 学習させる

では、全ての要素を組み合わせましょう:

DQNオブジェクトを作成し、learn()メソッドを呼び出すと、学習が始まります。

4. 学習結果検証

学習過程を観察するため、上記クラスにplot_training_scores()メソッドも加えました。

学習後にplot_training_scores()を呼び出すと、エージェントの各学習エピソードにおけるスコア(灰)と平均スコア(青)が示されます。

平均スコアが徐々に上がっていくのが、アルゴリズムがうまく動いていたことの証です。ちなみに、CartPole環境の最大ステップ数が500のため、最高スコアも500点です。

最後に、学習したQ-networkでCartPoleを遊ばせてみましょう:

CartPole-v1 score: 500.0

見事にバランスを保って、最高スコアを得られました。

5. まとめ

いかがでしょうか。今回はDeepMindのDQNアルゴリズムをPythonで再現し、gym環境にてテストしました。DQNでは見事に訓練データなしでクリア出来ました。さすがDeepMindと思います。

では、また次回!

6. 参考文献

Human-level control through deep reinforcement learning
https://deepmind.com/blog/deep-reinforcement-learning/

このページをシェアする: