基於sciket-learn實現SVM與核函式
阿新 • • 發佈:2018-11-22
支撐向量機(SVM)既可以用來解決分類問題,也可以解決迴歸問題,較多應用於解決分類問題,SVM嘗試尋找一個最優的角色邊界,距離兩個類別最近的樣本最遠,擁有較好的泛化能力。
下面從程式碼的角度一步步的來理解SVM
先引入常用類庫,匯入鳶尾花資料集,取兩個特徵
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets iris = datasets.load_iris() X = iris.data y = iris.target X = X[y<2,:2] y = y[y<2]
視覺化資料
plt.scatter(X[y==0,0], X[y==0,1], color='red')
plt.scatter(X[y==1,0], X[y==1,1], color='blue')
plt.show()
SVM和kNN一樣,在使用資料的時候,先進行資料標準化處理
from sklearn.preprocessing import StandardScaler
standardScaler = StandardScaler()
standardScaler.fit(X)
X_standard = standardScaler.transform(X)
然後匯入SVC構造器
from sklearn.svm import LinearSVC
svc = LinearSVC(C=1e9)
svc.fit(X_standard, y)
新增視覺化函式
def plot_decision_boundary(model, axis): x0, x1 = np.meshgrid( np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1), np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1), ) X_new = np.c_[x0.ravel(), x1.ravel()] y_predict = model.predict(X_new) zz = y_predict.reshape(x0.shape) from matplotlib.colors import ListedColormap custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9']) plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)
視覺化
plot_decision_boundary(svc, axis=[-3, 3, -3, 3])
plt.scatter(X_standard[y==0,0], X_standard[y==0,1])
plt.scatter(X_standard[y==1,0], X_standard[y==1,1])
plt.show()
關於核函式
核函式更像是一種數學技巧,把核函式應用在公式裡,避免了將樣本先進行變形,然後在將變形的結果進行點乘的步驟,使用核函式,可以節省空間,不用儲存變形後的多維資料。核函式並不是SVM的專屬,只要有類似與xi 點乘xj的項,就可以使用核函式。
接下來看看核函式的具體應用核表現
同樣還是先匯入常用類庫和資料集
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
X, y = datasets.make_moons(noise=0.15, random_state=666)
plt.scatter(X[y==0,0], X[y==0,1])
plt.scatter(X[y==1,0], X[y==1,1])
plt.show()
接下來還是同樣的處理方式,這裡有一個超引數gamma,gamma越大,越會擬合測試資料。
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
def RBFKernelSVC(gamma):
return Pipeline([
("std_scaler", StandardScaler()),
("svc", SVC(kernel="rbf", gamma=gamma))
])
這裡gamma取1進行訓練
svc = RBFKernelSVC(gamma=1)
svc.fit(X, y)
資料視覺化
def plot_decision_boundary(model, axis):
x0, x1 = np.meshgrid(
np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1),
np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1),
)
X_new = np.c_[x0.ravel(), x1.ravel()]
y_predict = model.predict(X_new)
zz = y_predict.reshape(x0.shape)
from matplotlib.colors import ListedColormap
custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)
plot_decision_boundary(svc, axis=[-1.5, 2.5, -1.0, 1.5])
plt.scatter(X[y==0,0], X[y==0,1])
plt.scatter(X[y==1,0], X[y==1,1])
plt.show()
感興趣的同學可以調節gamma的值,來檢視擬合的程度。