1. 程式人生 > >基於sciket-learn實現SVM與核函式

基於sciket-learn實現SVM與核函式

支撐向量機(SVM)既可以用來解決分類問題,也可以解決迴歸問題,較多應用於解決分類問題,SVM嘗試尋找一個最優的角色邊界,距離兩個類別最近的樣本最遠,擁有較好的泛化能力。

下面從程式碼的角度一步步的來理解SVM

先引入常用類庫,匯入鳶尾花資料集,取兩個特徵

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets

iris = datasets.load_iris()

X = iris.data
y = iris.target

X = X[y<2,:2]
y = y[y<2]

視覺化資料

plt.scatter(X[y==0,0], X[y==0,1], color='red')
plt.scatter(X[y==1,0], X[y==1,1], color='blue')
plt.show()

SVM和kNN一樣,在使用資料的時候,先進行資料標準化處理

from sklearn.preprocessing import StandardScaler

standardScaler = StandardScaler()
standardScaler.fit(X)
X_standard = standardScaler.transform(X)

然後匯入SVC構造器

from sklearn.svm import LinearSVC

svc = LinearSVC(C=1e9)
svc.fit(X_standard, y)

新增視覺化函式

def plot_decision_boundary(model, axis):
    
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1),
        np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]

    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
    
    plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)

視覺化

plot_decision_boundary(svc, axis=[-3, 3, -3, 3])
plt.scatter(X_standard[y==0,0], X_standard[y==0,1])
plt.scatter(X_standard[y==1,0], X_standard[y==1,1])
plt.show()

 關於核函式

核函式更像是一種數學技巧,把核函式應用在公式裡,避免了將樣本先進行變形,然後在將變形的結果進行點乘的步驟,使用核函式,可以節省空間,不用儲存變形後的多維資料。核函式並不是SVM的專屬,只要有類似與xi 點乘xj的項,就可以使用核函式。

接下來看看核函式的具體應用核表現

同樣還是先匯入常用類庫和資料集

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets

X, y = datasets.make_moons(noise=0.15, random_state=666)

plt.scatter(X[y==0,0], X[y==0,1])
plt.scatter(X[y==1,0], X[y==1,1])
plt.show()

接下來還是同樣的處理方式,這裡有一個超引數gamma,gamma越大,越會擬合測試資料。

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

def RBFKernelSVC(gamma):
    return Pipeline([
        ("std_scaler", StandardScaler()),
        ("svc", SVC(kernel="rbf", gamma=gamma))
    ])

這裡gamma取1進行訓練

svc = RBFKernelSVC(gamma=1)
svc.fit(X, y)

資料視覺化

def plot_decision_boundary(model, axis):
    
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1),
        np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]

    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
    
    plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)

plot_decision_boundary(svc, axis=[-1.5, 2.5, -1.0, 1.5])
plt.scatter(X[y==0,0], X[y==0,1])
plt.scatter(X[y==1,0], X[y==1,1])
plt.show()

感興趣的同學可以調節gamma的值,來檢視擬合的程度。