カスタムGym環境作成(2) - PPO(PPO2)で学習

前回カスタム環境として、簡単なマップを作りスタート地点からゴール地点まで移動する環境を作成しました。

今回はStable BaselinesPPO(PPO2)アルゴリズムを使って、強化学習を行いそのカスタム環境を効率よく攻略してみます。

強化学習

カスタム環境を読み込み、PPO(PPO2)アルゴリズムで学習を行います。

PPO(PPO2)は計算量を削除するように改良された学習法で、使いやすさと優れたパフォーマンスのバランスがとれているためOpenAIのデフォルトの強化学習アルゴリズムとなっています。

学習したモデルはsaveメソッドでファイルに保存します。

[ソース]

train3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 警告を非表示
import warnings
warnings.simplefilter('ignore')
import tensorflow as tf
tf.get_logger().setLevel("ERROR")

import gym
from env3 import MyEnv

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の生成
env = MyEnv()
env = DummyVecEnv([lambda: env]) # ベクトル化環境でラップ

# モデルの生成
model = PPO2('MlpPolicy', env, verbose=1) # MlpPolicy:入力が特徴量の場合に適した方策

# モデルの学習
model.learn(total_timesteps=12800*10) # 訓練ステップ数を指定

# モデルの保存
model.save('model3') # 学習済みモデルをファイル保存

[結果(1部略)]

--------------------------------------
| approxkl           | 0.00015537896 |(新しい方策から古い方策へのKullback-Leibler発散尺度)
| clipfrac           | 0.0           |(クリップ範囲ハイパーパラメータが使用される回数の割合)
| explained_variance | 0.00605       |(誤差の分散)
| fps                | 276           |(1秒あたりのフレーム数)
| n_updates          | 1             |(更新回数)
| policy_entropy     | 1.3861711     |(方策のエントロピー)
| policy_loss        | -0.0009599341 |(方策の損失)
| serial_timesteps   | 128           |(1つの環境でのタイプステップ数)
| time_elapsed       | 0             |(経過時間)
| total_timesteps    | 128           |(全環境でのタイムステップ数)
| value_loss         | 113.650955    |(価値関数更新時の平均損失)
--------------------------------------
--------------------------------------
| approxkl           | 0.00017739396 |
| clipfrac           | 0.0           |
| explained_variance | -0.00457      |
| fps                | 1085          |
| n_updates          | 2             |
| policy_entropy     | 1.3849432     |
| policy_loss        | 0.0005420535  |
| serial_timesteps   | 256           |
| time_elapsed       | 0.463         |
| total_timesteps    | 256           |
| value_loss         | 332.47357     |
--------------------------------------
---------------------------------------
| approxkl           | 4.4801734e-05  |
| clipfrac           | 0.0            |
| explained_variance | -0.00433       |
| fps                | 1105           |
| n_updates          | 3              |
| policy_entropy     | 1.3839808      |
| policy_loss        | -0.00094374514 |
| serial_timesteps   | 384            |
| time_elapsed       | 0.584          |
| total_timesteps    | 384            |
| value_loss         | 112.30096      |
---------------------------------------
         :
        (略)
         :
--------------------------------------
| approxkl           | 0.004464049   |
| clipfrac           | 0.005859375   |
| explained_variance | 0.995         |
| fps                | 1260          |
| n_updates          | 998           |
| policy_entropy     | 0.005367896   |
| policy_loss        | -0.0047663264 |
| serial_timesteps   | 127744        |
| time_elapsed       | 96.8          |
| total_timesteps    | 127744        |
| value_loss         | 0.0663212     |
--------------------------------------
---------------------------------------
| approxkl           | 1.402045e-09   |
| clipfrac           | 0.0            |
| explained_variance | 1              |
| fps                | 1259           |
| n_updates          | 999            |
| policy_entropy     | 0.005131081    |
| policy_loss        | -2.0012958e-05 |
| serial_timesteps   | 127872         |
| time_elapsed       | 96.9           |
| total_timesteps    | 127872         |
| value_loss         | 0.0038208982   |
---------------------------------------
--------------------------------------
| approxkl           | 3.59563e-10   |
| clipfrac           | 0.0           |
| explained_variance | 1             |
| fps                | 1294          |
| n_updates          | 1000          |
| policy_entropy     | 0.0052420935  |
| policy_loss        | -7.759663e-06 |
| serial_timesteps   | 128000        |
| time_elapsed       | 97            |
| total_timesteps    | 128000        |
| value_loss         | 0.0014512216  |
--------------------------------------

学習済みモデルがmodel3.zipというファイル名で保存されます。

学習したモデルを使って実行

学習したモデルを読み込み、実行します。

[ソース]

play3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# 警告を非表示
import warnings
warnings.simplefilter('ignore')
import tensorflow as tf
tf.get_logger().setLevel("ERROR")

import gym
#from env2 import MyEnv
from env3 import MyEnv

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

# 環境の生成
env = MyEnv()
env = DummyVecEnv([lambda: env])

# モデルの読み込み
model = PPO2.load('model3')

# モデルのテスト
state = env.reset()
total_reward = 0
while True:
# 環境の描画
env.render()

# モデルの推論
action, _ = model.predict(state)

# 1ステップの実行
state, reward, done, info = env.step(action)
total_reward += reward
print('reward:', reward, 'total_reward', total_reward)
print('-----------')

print('')
# エピソード完了
if done:
# 環境の描画
print('total_reward:', total_reward)
break

[結果]

Loading a model without an environment, this model cannot be trained until it has a valid environment.
☆ 山 
 山G 
  山 
山   
reward: [-1.] total_reward [-1.]
-----------

S 山 
☆山G 
  山 
山   
reward: [-1.] total_reward [-2.]
-----------

S 山 
 山G 
☆ 山 
山   
reward: [-1.] total_reward [-3.]
-----------

S 山 
 山G 
 ☆山 
山   
reward: [-1.] total_reward [-4.]
-----------

S 山 
 山G 
  山 
山☆  
reward: [-1.] total_reward [-5.]
-----------

S 山 
 山G 
  山 
山 ☆ 
reward: [-1.] total_reward [-6.]
-----------

S 山 
 山G 
  山 
山  ☆
reward: [-1.] total_reward [-7.]
-----------

S 山 
 山G 
  山☆
山   
reward: [-1.] total_reward [-8.]
-----------

S 山 
 山G☆
  山 
山   
reward: [100.] total_reward [92.]
-----------

total_reward: [92.]

何回実行しても、最短でゴールに向かって進んでいっているのが分かります。きちんと学習できているようです。

適切な環境さえ用意すれば、いろいろな問題を確実に解いてしまう強化学習はやはりすごいですね。


次回は、学習中のログを出力し取得報酬がどのように変化しているのかをグラフ化して確認してみます。