atari - 模倣学習① 人間のデモ収集

ランダム行動では報酬を見つけにくい環境に対応するために模倣学習を試してみます。

Atari環境の1つであるボーリングゲーム(Bowling)を実行環境とします。

(Windowsではうまく動作しなかったので、Ubuntu 19.10で動作確認しています。)

インストール

下記のコマンドを実行し、実行環境をインストールします。

1
2
3
4
5
6
pip3 install gym
apt install cmake libopenmpi-dev python3-dev zlib1g-dev
pip3 install stable_baselines[mpi]
pip3 install tensorflow==1.14.0
pip3 install imageio
pip3 install baselines

人間のデモ収集

人間のデモ収集を行うコードは下記になります。

[コード]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import random
import pyglet
import gym
import time
from pyglet.window import key
from stable_baselines.gail import generate_expert_traj
from baselines.common.atari_wrappers import *

# 環境を作成
env = gym.make('BowlingNoFrameskip-v0')
env = MaxAndSkipEnv(env, skip=4) # 4フレームごとに行動を選択
env = WarpFrame(env) # 画面イメージを84x84のグレースケールに変換
env.render()

# キーイベント用のウィンドウ作成
win = pyglet.window.Window(width=300, height=100, vsync=False)
key_handler = pyglet.window.key.KeyStateHandler()
win.push_handlers(key_handler)
pyglet.app.platform_event_loop.start()

# キー状態の取得
def get_key_state():
key_state = set()
win.dispatch_events()
for key_code, pressed in key_handler.items():
if pressed:
key_state.add(key_code)
return key_state

# キー入力待ち
while len(get_key_state()) == 0:
time.sleep(1.0/30.0)

# 人間のデモを収集するコールバック
def human_expert(_state):
key_state = get_key_state() # キー状態の取得
action = 0 # 行動の選択

if key.SPACE in key_state:
action = 1
elif key.UP in key_state:
action = 2
elif key.DOWN in key_state:
action = 3

time.sleep(1.0/30.0) # スリープ
env.render() # 環境の描画
return action # 行動の選択

# 人間のデモの収集
generate_expert_traj(human_expert, 'bowling_demo', env, n_episodes=1)

デモ収集にはgenerate_expert_trajを使います。引数の意味は下記の通りです。

  • model(モデルまたはコールバック型)
    モデルまたはコールバック
  • save_path(str型)
    保存先のデモファイルのパス(拡張子なし)
  • env(gym.Env型)
    環境
  • n_timesteps(int型)
    モデルの学習ステップ数
  • n_episodes(int型)
    記録するエピソード数
  • image_folder(str型)
    画像を使用する場合の保存フォルダ

返値はデモ demo(dict型)となります。

実行

実行すると、次のような画面が表示されます。右側のウィンドウにフォーカスをあてるとゲームを操作することができます。

実行結果

updownで位置を選択し、fireでボールを投げます。
ボールを投げた後にupdownでボールの起動を曲げることができます。

10ゲーム(1エピソード)の人間の操作が収集され、bowling_demo.npzファイルrecorded_imagesフォルダが出力されます。

  • bowling_demo.npzファイル
    Pythonの辞書形式で保存されます。
    キーとしてactionsepisode_returnsrewardsobsepisode_startsがあり、obsには画像への相対パスが格納されます。
  • recorded_imagesフォルダ
    各状態の画像が保存されます。

次回は、今回収集した人間のデモデータを使って事前学習を行います。