Python scikit-learn - アヤメの品種分類をクロスバリデーションで行う

アヤメの品種分類をクロスバリデーションで行います。

クロスバリデーションとは、最初に全てのデータを訓練データとテストデータに分割して、訓練データを用いて学習を行い、テストデータを用いて学習の妥当性を検証する手法です。

クロスバリデーションにはいろいろな手法がありますが、今回はK分割交差法をご紹介します。

【例 集合XをA,B,Cと3分割する場合】

(1) 集合Xを、AとBとCに分割します。
(2) Aとテストデータ、残りのB,Cを訓練データとして分類精度s1を求めます。
(3) Bとテストデータ、残りのA,Cを訓練データとして分類精度s1を求めます。
(4) Cとテストデータ、残りのA,Bを訓練データとして分類精度s1を求めます。
(5) 分類精度s1,s2,s3の平均を求め分類精度とします。

クロスバリデーション

実行するコードは下記の通りです。

[コード]

1
2
3
4
5
6
7
8
9
10
11
12
13
import pandas as pd
from sklearn import svm, metrics, model_selection, datasets

# アヤメのCSVデータを読み込む
iris = datasets.load_iris()

# データの学習
clf = svm.SVC()
scores = model_selection.cross_val_score(clf, iris.data, iris.target, cv=5)

# 正答率を求める
print('各正解率:', scores)
print('正解率:', scores.mean())

9行目のcv=5で分割数を5に設定しています。

model_selection.cross_val_score関数1つで、複数回の検証を行えるのは大変便利です。


実行結果は次のようになります。

[実行結果]

各正解率: [0.96666667 0.96666667 0.96666667 0.93333333 1.        ]
正解率: 0.9666666666666666

正答率96%以上と十分な結果となります。