numpyで乱数生成・データ分類・分類したデータをグラフ表示

numpyで生成した乱数を、特定関数で分類し、その分類したデータをグラフに表示します。
まずは必要なモジュールをimportします。

1
2
3
# 基本モジュールの宣言
import numpy as np
import matplotlib.pyplot as plt

次に乱数を生成します。縦100横2の配列にreshapeしています。

1
2
3
4
5
# 乱数生成(縦100,横2の配列分の乱数を生成)
D = 100
N = 2
xdata = np.random.randn(D * N).reshape(D, N).astype(np.float32)
xdata

結果(一部略)
生成した配列をx,yを表した100個のデータとみなして散布図を表示します。

1
2
3
4
5
# 散布図を表示
# xdata[:,0] <= 1列目に縦に並んでいる数字の全て
# xdata[:,1] <= 2列目に縦に並んでいる数字の全て
plt.scatter(xdata[:,0], xdata[:,1])
plt.show()

結果
分類用の関数を定義します。ここでは単純にxを2乗した数値を返しています。

1
2
3
# 関数の定義
def f(x):
return x * x

乱数の配列に対して、定義した関数の上にくるか下にくるかを判定します。
上にくる場合は1を、下にくる場合は0が返ります。
(ループで1つずつ処理しなくても入力の配列に対して、結果の配列が返ってくるのがnumpyの便利なところです。)

1
2
3
# 条件にあてはまるものを探す
tdata = (xdata[:,1] > f(xdata[:,0])).astype(np.int32)
tdata

結果
結果がTrue(1)となる配列のインデックスをndata0に、False(0)となる配列のインデックスをndata1に格納します。

1
2
3
4
5
6
# 乱数のデータを2つのグループに分ける
# True(1)とFalse(0)の場所を調べる
ndata0 = np.where(tdata==0)
ndata1 = np.where(tdata==1)
print(ndata0)
print(ndata1)

結果
最後にグループ分け関数と分類されたデータを1つのグラフとして表示します。

1
2
3
4
5
6
7
8
# 2つの種類のデータを図に示す
x = np.linspace(-2.0, 2.0, D) # -2.0から2.0の範囲でD=100個の点を用意する
plt.plot(x, f(x))

plt.scatter(xdata[ndata0, 0], xdata[ndata0, 1], marker='x')
plt.scatter(xdata[ndata1, 0], xdata[ndata1, 1], marker='o')

plt.show()

結果

(Google Colaboratoryで動作確認しています。)