Stable Baselines② - PPO2での学習

Stable Baselines の強化学習アルゴリズムの1つである PPO2 を使ってCartPole-v1を攻略してみます。

PPO2での学習

PPO2 は、マルチプロセッシングで訓練可能な強化学習アルゴリズムです。

各処理のポイントはコメントをご参照ください。

[コード]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gym
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の生成
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env]) # 環境をベクトル化環境でラップする

# モデルの作成
# 'MlpPolicy':方策
# env:ベクトル化環境
# verbose:ログの詳細表示(0:ログなし、1:訓練情報を表示、2:TensorFlowログを表示)
model = PPO2('MlpPolicy', env, verbose=1)

# モデルの学習
# total_timesteps:訓練ステップ数
model.learn(total_timesteps=128000)

# モデルのテスト
state = env.reset()
while True:
# 環境の描画
env.render()
# モデルの推論
# deterministic=Trueは「ある状態に対して行動が一意に決まるようにする」
# →パフォーマンスが向上する
action, _ = model.predict(state, deterministic=True)
# 1ステップ実行
state, rewards, done, info = env.step(action)
# エピソード完了
if done:
break
# 環境のクローズ
env.close()

実行すると学習が行われた後、ゲーム画面が表示され棒のバランスがとれていることが確認できます。
[実行結果]