説明可能なAI①(SHAP)

SHAP

SHAPは、学習済みモデルにおいて各説明変数が予測値にどのような影響を与えたかを貢献度として定義して算出するモデルです。

各データごとに結果を出力し、可視化することができます。

前準備

前準備として、ボストンデータセットを用いた回帰モデル(決定木)を作成し、予測結果を確認します。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import pandas as pd
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

boston = load_boston()
df = pd.DataFrame(boston.data,columns=boston.feature_names)
df["MEDV"] = boston.target
X= df[boston.feature_names]
y = df[["MEDV"]]
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.3,random_state=0)

print(len(X_train))
display(X_train.head(1))
print(len(X_test))
display(X_test.head(1))

tree_reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_train,y_train)

[実行結果]

以上で、回帰系の決定木モデルが作成できました。

重要度

作成したモデルの説明変数ごとに重要度を表示します。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
import matplotlib.pyplot as plt
import numpy as np

features = X_train.columns
importances = tree_reg.feature_importances_
indices = np.argsort(importances)

plt.figure(figsize=(6,6))
plt.barh(range(len(indices)), importances[indices], color="b", align="center")
plt.yticks(range(len(indices)), features[indices])
plt.show()

feature_importances_を参照し、説明変数ごとの重要度を取得しています。(5行目)

[実行結果]

このモデルでは、RMLSTATなどの重要度が高く、予測に強く影響していることが確認できます。

予測値の算出

予測値を算出します。

[Google Colaboratory]

1
2
3
X_test_pred = X_test.copy()
X_test_pred["pred"] = np.round(tree_reg.predict(X_test), 2)
X_test_pred.describe()[["RM","LSTAT","CRIM","DIS","PTRATIO","pred"]]

テストデータで予測値を算出し、結果を説明変数とマージしています。(2行目)

重要度の高い説明変数の上位5項目と、予測値(pred)を表示しています。(3行目)

[実行結果]

RMは8.7付近が最大値となっており、predの最大は50になっています。

予測値の表示

最も重要度の高かった説明変数であるRMでソートして、結果を確認してみます。

[Google Colaboratory]

1
X_test_pred.sort_values("RM")

[実行結果]

predの最大値が50だったので、RMが高いほどpredも高く、RMが低いほどpredも低く出ている傾向が見られます。

2番目に重要度が高かったLSTATはその逆で、LSTATが高いとpredは低く、LSTATが低ければpredが高くなっているようです。

このようにfeature_importances_は、モデル作成時にどのような説明変数が重要であるかを知るために大局的な指標となります。

一方SHAPは、作成したモデルの各説変数がどのように予測に寄与してしるかを知るための局所的な指標となります。

次回はSHAPを実装して予測結果を確認していきます。