1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
| import os, gym import datetime import gym_anytrading import matplotlib.pyplot as plt from gym_anytrading.envs import TradingEnv, ForexEnv, StocksEnv, Actions, Positions from gym_anytrading.datasets import FOREX_EURUSD_1H_ASK, STOCKS_GOOGL from stable_baselines.common.vec_env import DummyVecEnv from stable_baselines import PPO2 from stable_baselines import ACKTR from stable_baselines.bench import Monitor from stable_baselines.common import set_global_seeds
import numpy as np import matplotlib.pyplot as plt
def simulation(i, prm): global means # ログフォルダの生成 log_dir = './logs/' os.makedirs(log_dir, exist_ok=True) # 環境の生成 env = gym.make('forex-v0', frame_bound=(prm['start_idx'], prm['end_idx']), window_size = prm['window_size']) env = Monitor(env, log_dir, allow_early_resets=True) # シードの指定 env.seed(0) set_global_seeds(0) # ベクトル化環境の生成 env = DummyVecEnv([lambda: env]) # モデルの読み込み # model = PPO2.load('model{}'.format(i)) model = ACKTR.load('model{}'.format(i)) # モデルのテスト env = gym.make('forex-v0', frame_bound=(prm['start_idx'] + prm['move_idx'], prm['end_idx'] + prm['move_idx']), window_size = prm['window_size']) env.seed(0) state = env.reset() while True: # 行動の取得 action, _ = model.predict(state) # 0 or 1 # 1ステップ実行 state, reward, done, info = env.step(action) # エピソード完了 if done: print('info:', info, info['total_reward']) means.append(info['total_reward']) break # グラフのプロット plt.cla() env.render_all()
for move_idx in range(0, 1201, 50): labels = [] means = [] prm = {'window_size': 10, #window_size 参照すべき直前のデータ数 'start_idx' : 10, #start_idx 学習データの開始位置 'end_idx' : 110, #end_idx 学習データの終了位置 'move_idx' : move_idx} #学習データからの移動分。移動したものを検証データとする。 for i in range(50): labels.append('{}'.format(i)) simulation(3, prm)
x = np.arange(len(labels)) width = 0.35
fig, ax = plt.subplots()
rect = ax.bar(x, means, width) ax.set_xticks(x) ax.set_xticklabels(labels)
plt.savefig('trading{:03d}.png'.format(move_idx))
|