Python × AI - クラスタリング(MeanShift)

MeanShiftを使ったクラスタリングを行います。

MeanShiftクラスタ数が分からない場合に、データをクラスタを分類する方法です。

複数のガウス分布(正規分布)を仮定して、各データがどのガウス分布に所属するのかを決定し、クラスタ分析を行います。

ワインの分類データセット

まず、オープンデータであるワインの分類データセットを準備します。

[Google Colaboratory]

1
2
3
4
df_wine_all = pd.read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data", header=None)
df_wine = df_wine_all[[0,10,13]]
df_wine.columns = [u"class", u"color", u"proline"]
pd.DataFrame(df_wine)

今回はワインの品種(0列目)、色(10列目)、プロリン量(13列目)を使用します。

[実行結果(一部略)]

データセットを可視化

抽出したワインのデータセットを可視化します。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
X = df_wine[["color","proline"]]
sc = preprocessing.StandardScaler()
X_norm = sc.fit_transform(X)
x = X_norm[:,0]
y = X_norm[:,1]
z = df_wine["class"]
plt.figure(figsize=(10,3))
plt.scatter(x,y, c=z)
plt.show

[実行結果]

3品種ごとに色分けされたデータが確認できます。

k-meansでクラスタリング

k-meansでクラスタリングを行います。

品種ごとに3つに分類したいのでクラスタ数は3に設定します。

[Google Colaboratory]

1
2
3
4
5
6
7
km = cluster.KMeans(n_clusters=3)
z_km = km.fit(X_norm)

plt.figure(figsize=(10,3))
plt.scatter(x,y, c=z_km.labels_)
plt.scatter(z_km.cluster_centers_[:,0],z_km.cluster_centers_[:,1],s=250, marker="*",c="red")
plt.show

[実行結果]

中心点から同心円状に広がって分類されていることが分かります。

MeanShiftでクラスタリング

MeanShift関数(1行目)を使って、MeanShiftでクラスタリングを行います。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
ms = cluster.MeanShift(seeds=X_norm)
ms.fit(X_norm)
labels = ms.labels_
cluster_centers = ms.cluster_centers_
print(cluster_centers)

plt.figure(figsize=(10,3))
plt.scatter(x,y, c=labels)
plt.plot(cluster_centers[0,0], cluster_centers[0,1], marker="*",c="red", markersize=14)
plt.plot(cluster_centers[1,0], cluster_centers[1,1], marker="*",c="red", markersize=14)
plt.show

[実行結果]

MeanShiftでは2つに分類されて実際の分類とはだいぶ違う結果となりました。

MeanShiftではクラスタ数を指定しなくてもクラスタリングが実施可能なので、パラメータはseeds(乱数シード)のみ設定しています。(1行目)

MeanShiftはk-meansをベースとして近いクラスタをまとめていき、既定の距離より近くなったクラスタをまとめて1つにします。