使用方法について

ここでReNomRLの簡単なチュートリアルを紹介する. 本チュートリアルではDQNを基準に説明する.

ReNom RLの特徴として, DQN, A3C等の複雑なアルゴリズムが既に実装されている. 強化学習を実装する上において, 主にユーザーに実施する操作として以下の3つがある.

1- 環境の準備

簡易的に学習するために, 環境モデルをBaseEnvに合わせた構造にする必要がある. ここでは既存のモデルを使った方法と, 0から実装する方法を紹介する.

1.1 既存のモデルを使った方法

Open AIを使った環境モデルを一つの関数として用意してある. 例えばCartPoleのモデルを利用する場合は, 以下のように呼び出すこともできる.

from renom_rl.environ.openai import Breakout
env = CartPole00()

1.2 0から実装する方法

オリジナルの環境を作成する場合は, BaseEnvを継承し, 以下の変数および関数を書き換える必要がある.

  • action_shape: actionの形状

  • state_shape: stateの形状

  • 環境リセット時に実行. 初期状態を返す.

  • step(): action 取った時に実行. 次のstate, reward, terminal を返す.

  • sample(): ランダムな行動を選択する時に返す. action のサンプル結果を返す. (DQN, DDQN等)

例えば CustomEnv() というオブジェクトを新たに作る場合, CustomEnv は以下のようになる.

class CustomEnv(BaseEnv):

    def __init__(self, env):
        self.action_shape = (2,)
        self.state_shape = (4,)

        self.env = env
        self.step_continue = 0
        self.reward = 0



    def reset(self):
        return self.env.reset()


    def sample(self):
        rand = self.env.action_space.sample()
        return rand

    def step(self, action):
        state, _, terminal, _ = self.env.step(int(action))

        self.step_continue += 1
        reward = 0

        if terminal:
            if self.step_continue >= 200:
                reward = 1
            else:
                reward = -1

        self.reward=reward

        return state, reward, terminal

new_env = CustomEnv()

2- モデルの準備

ここではネットワークを定義する. ネットワークの構造は問題によって, 変わる. DQNの場合は以下のように定義する.

import renom as rm
q_network = rm.Sequential([rm.Dense(30, ignore_bias=True),
                           rm.Relu(),
                           rm.Dense(30, ignore_bias=True),
                           rm.Relu(),
                           rm.Dense(2, ignore_bias=True)])

3-強化学習の実装

環境とネットワークの構造を定義した後, DQNを実装する.

from renom_rl.discrete.dqn import DQN

algorithm = DQN(custom_env, q_network)

DQNのインスタンス生成後, 以下のコマンドで学習を実行する.

result = algorithm.fit()

各エポック終了後, テストを実行する.

学習で使用した環境は, 以下のように実行することで, テスト単体もできる.

result = algorithm.test()

このように実装することで, 定義したネットワークをDQNのアルゴリズムで学習することができる. 環境, その他の強化学習アルゴリズムについてはAPIページを参照いただけると幸いです.

How to Use - Detail - では, ReNomRL の使用方法について, より詳細に記述されている.