1. 程式人生 > >sklearn的快速使用之六(決策樹分類)

sklearn的快速使用之六(決策樹分類)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Parameters
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02

# Load data
iris = load_iris()

print (iris.data)
print (iris.data[:, [0, 1]])
#print (iris.data[:, [1, 2]])
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
                                [1, 2], [1, 3], [2, 3]]):
    #  [0,1,2,3,4]從四列資料中選取2個要素進行訓練
    X = iris.data[:, pair]  
   
    y = iris.target

    # Train
    clf = DecisionTreeClassifier().fit(X, y)

    # 2行3列排列圖片
    plt.subplot(2, 3, pairidx + 1)
    #第一列
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    print (x_min,x_max)
    #第二列
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    print (y_min,y_max)
    # 繪製網格  xx 分割數 × yy 分割數 =  x×y 維資料 
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                         np.arange(y_min, y_max, plot_step))    
    print (xx)
    print (yy)
    #plt.tight_layout()進行自動控制,此方法不能夠很好的控制影象間的間隔
    plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
    print ("---------------")
    print (xx.ravel())
    print ("---------------")
    print (yy.ravel())
    '''
    把第一列花萼長度資料按h取等分,作為行,然後複製多行,得到xx網格矩陣
    把第二列花萼寬度資料按h取等分,作為列,然後複製多列,得到yy網格矩陣
    xx和yy矩陣都變成兩個一維陣列,然後到np.c_[] 函式組合成一個二維陣列
    '''
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    print (Z)
  
    Z = Z.reshape(xx.shape)
    print (Z)
    #繪製等高線
    cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)
    #橫座標  縱座標 
    plt.xlabel(iris.feature_names[pair[0]])
    plt.ylabel(iris.feature_names[pair[1]])

    # Plot the training points
    for i, color in zip(range(n_classes), plot_colors):
        idx = np.where(y == i)
        plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
                    cmap=plt.cm.RdYlBu, edgecolor='black', s=15)
plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
plt.axis("tight")
plt.show()