1. 程式人生 > Python機器學習庫sklearn裡利用感知機進行三分類(多分類)的原理

Python機器學習庫sklearn裡利用感知機進行三分類(多分類)的原理

from IPython.display import Image  
%matplotlib inline  
# Added version check for recent scikit-learn 0.18 checks  
from distutils.version import LooseVersion as Version  
from sklearn import __version__ as sklearn_version  
  
from sklearn import datasets  
import numpy as np  
# Load the iris data set bundled with scikit-learn.
# See http://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html
iris = datasets.load_iris()
X = iris.data[:, [2, 3]]  # keep only feature columns 2 and 3
y = iris.target  # class labels (the species column)
  
# train_test_split is imported from sklearn.cross_validation for versions
# below 0.18 and from sklearn.model_selection otherwise.
if Version(sklearn_version) < '0.18':  
    from sklearn.cross_validation import train_test_split  
else:  
    from sklearn.model_selection import train_test_split  
# Split the data set with train_test_split (test_size=0.3, fixed seed).
X_train, X_test, y_train, y_test = train_test_split(  
    X, y, test_size=0.3, random_state=0)
  
from sklearn.preprocessing import StandardScaler  
sc = StandardScaler()   # create the scaler object sc used to transform the data
sc.fit(X_train)   # fit on X_train only and keep the fitted parameters
X_train_std = sc.transform(X_train)  # the same fitted parameters are applied to both splits
X_test_std = sc.transform(X_test)  

from sklearn.linear_model import Perceptron
# API reference:
# http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html#sklearn.linear_model.Perceptron
# Earlier variant with explicit hyperparameters, kept for reference:
#ppn = Perceptron(n_iter=40, eta0=0.1, random_state=0)
ppn = Perceptron()  # linear model y = w.x + b, constructed with default arguments
ppn.fit(X_train_std, y_train)  # train on the standardized training split

# Verify how the Perceptron predicts by recomputing its decision by hand
def prelabmax(X_test_std, clf=None):
    """Return the highest per-class linear score of every sample.

    For each row ``x`` of ``X_test_std`` the scores
    ``coef_ . x + intercept_`` are computed (one score per row of
    ``coef_``) and the largest one is kept.

    Args:
        X_test_std: 2-D array with one sample per row.
        clf: Fitted linear model exposing ``coef_`` and ``intercept_``.
            Defaults to the module-level ``ppn`` when omitted.

    Returns:
        A list holding one maximum score per sample, in row order.
    """
    model = ppn if clf is None else clf
    weights, bias = model.coef_, model.intercept_
    # Scores are computed row by row (not as one matrix product) so the
    # values stay bit-identical to the ones recomputed in prelabindex,
    # which compares them with exact equality.
    return [
        max(np.dot(weights, X_test_std[i, :].T) + bias)
        for i in range(X_test_std.shape[0])
    ]
# Compute the per-sample maximum scores; the result is not stored here
# (presumably it is shown as the notebook cell output).
prelabmax(X_test_std)   

def prelabindex(X_test_std, pym, clf=None):
    """Return, for every sample, the index of the class reaching ``pym[i]``.

    The per-class scores ``coef_ . x + intercept_`` are recomputed for
    each row of ``X_test_std`` and compared with the expected maximum
    ``pym[i]``. Exactly one index is emitted per sample so the result
    lines up element-wise with the samples.

    Args:
        X_test_std: 2-D array with one sample per row.
        pym: Sequence with the maximum score of each sample, as returned
            by ``prelabmax`` for the same data and model.
        clf: Fitted linear model exposing ``coef_`` and ``intercept_``.
            Defaults to the module-level ``ppn`` when omitted.

    Returns:
        A 1-D integer array with one class index per sample.

    Raises:
        ValueError: If ``pym[i]`` equals none of the scores of sample
            ``i`` (``pym`` does not belong to this data/model).
    """
    model = ppn if clf is None else clf
    weights, bias = model.coef_, model.intercept_
    labels = []
    for i in range(X_test_std.shape[0]):
        scores = np.dot(weights, X_test_std[i, :].T) + bias
        target = pym[i]
        # Walk over however many classes the model has instead of assuming
        # three, and stop at the first hit so a tie between classes still
        # yields a single label for the sample.
        for j, score in enumerate(scores):
            if score == target:
                labels.append(j)
                break
        else:
            # Skipping the sample would silently shift every later label.
            raise ValueError(
                "pym[%d] does not match any class score of sample %d" % (i, i)
            )
    # dtype=int keeps the integer dtype even when there are no samples.
    return np.array(labels, dtype=int)
# Maximum linear score of every test sample.
pym=prelabmax(X_test_std)
# Class index recovered from those scores for every test sample.
prelabindex(X_test_std,pym)
# Element-wise comparison of the recovered indices with ppn.predict.
prelabindex(X_test_std,pym)==ppn.predict(X_test_std)
# Recorded output of the comparison above:
#Output:array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
#               True,  True,  True,  True,  True,  True,  True,  True,  True,
#               True,  True,  True,  True,  True,  True,  True,  True,  True,
#               True,  True,  True,  True,  True,  True,  True,  True,  True,
#               True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)

即:對每個樣本計算各類別的 y = w·x + b,取值最大的那一項所對應的類別作為該樣本的預測類別。