Python OpenAI Gym - CartPole(棒たてゲーム)を試す② 強化学習編

Stable Baselinesという強化学習アルゴリズムを使ってCartPoleを実行します。

インストール

下記のコマンドを実行しStable Baselinesを準備します。

1
2
3
4
pip install stable-baselines[mpi]
pip install tensorflow==1.14.0
pip install pyqt5
pip install imageio

強化学習アルゴリズムを使ってCartPole実行

強化学習のモデルを作成し、100,000回学習を行ってからCartPoleを実行してみます。

[コード]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gym
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の作成
env = gym.make('CartPole-v0')
env = DummyVecEnv([lambda: env])

# モデルの作成
model = PPO2('MlpPolicy', env, verbose=1)

# モデルの学習
model.learn(total_timesteps=100000)

# モデルのテスト
state = env.reset()
for i in range(200):
# 環境の描画
env.render()

# モデルの推論
action, _ = model.predict(state, deterministic=True)

# 1ステップ実行
state, rewards, done, info = env.step(action)

# エピソード完了判定
if done:
break

# 環境のクローズ
env.close()

実行してみると、棒が倒れることなくうまくバランスをとっていることが確認できます。

実行結果