機器學習(四):通俗理解支援向量機SVM及程式碼實踐
阿新 • • 發佈:2021-02-15
[上一篇文章](https://mp.weixin.qq.com/s/cEbGM0_Lrt8elfubxSF9jg)我們介紹了使用邏輯迴歸來處理分類問題,本文我們講一個更強大的分類模型。本文依舊側重程式碼實踐,你會發現我們解決問題的手段越來越豐富,問題處理起來越來越簡單。
支援向量機(Support Vector Machine, SVM)是最受歡迎的機器學習模型之一。它特別適合處理中小型複雜資料集的分類任務。
# 一、什麼是支援向量機
SMV在眾多例項中尋找一個最優的決策邊界,這個邊界上的例項叫做支援向量,它們“支援”(支撐)分離開超平面,所以它叫支援向量機。
那麼我們如何保證我們得到的決策邊界是**最優**的呢?
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154028317-184832083.png)
如上圖,三條黑色直線都可以完美分割資料集。由此可知,我們僅用單一直線可以得到無數個解。那麼,其中怎樣的直線是最優的呢?
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154059512-1481536751.png)
如上圖,我們計算直線到分割例項的距離,使得我們的直線與資料集的**距離儘可能的遠**,那麼我們就可以得到唯一的解。最大化上圖虛線之間的距離就是我們的目標。而上圖中重點圈出的例項就叫做支援向量。
這就是支援向量機。
# 二、從程式碼中對映理論
## 2.1 匯入資料集
新增引用:
```
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
```
匯入資料集(大家不用在意這個域名):
```
df = pd.read_csv('https://blog.caiyongji.com/assets/mouse_viral_study.csv')
df.head()
```
| | Med_1_mL | Med_2_mL | Virus Present |
|---:|-----------:|-----------:|----------------:|
| 0 | 6.50823 | 8.58253 | 0 |
| 1 | 4.12612 | 3.07346 | 1 |
| 2 | 6.42787 | 6.36976 | 0 |
| 3 | 3.67295 | 4.90522 | 1 |
| 4 | 1.58032 | 2.44056 | 1 |
該資料集模擬了一項醫學研究,對感染病毒的小白鼠使用不同劑量的兩種藥物,觀察兩週後小白鼠是否感染病毒。
* **特徵**: 1. 藥物Med_1_mL 藥物Med_2_mL
* **標籤**:是否感染病毒(1感染/0不感染)
## 2.2 觀察資料
```
sns.scatterplot(x='Med_1_mL',y='Med_2_mL',hue='Virus Present',data=df)
```
我們用seaborn繪製兩種藥物在不同劑量特徵對應感染結果的散點圖。
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154109721-466787380.png)
```
sns.pairplot(df,hue='Virus Present')
```
我們通過pairplot方法繪製特徵兩兩之間的對應關係。
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154115641-379303725.png)
我們可以做出大概的判斷,當加大藥物劑量可使小白鼠避免被感染。
## 2.3 使用SVM訓練資料集
```
#SVC: Supprt Vector Classifier支援向量分類器
from sklearn.svm import SVC
#準備資料
y = df['Virus Present']
X = df.drop('Virus Present',axis=1)
#定義模型
model = SVC(kernel='linear', C=1000)
#訓練模型
model.fit(X, y)
# 繪製圖像
# 定義繪製SVM邊界方法
def plot_svm_boundary(model,X,y):
X = X.values
y = y.values
# Scatter Plot
plt.scatter(X[:, 0], X[:, 1], c=y, s=30,cmap='coolwarm')
# plot the decision function
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# create grid to evaluate model
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = model.decision_function(xy).reshape(XX.shape)
# plot decision boundary and margins
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5,
linestyles=['--', '-', '--'])
# plot support vectors
ax.scatter(model.support_vectors_[:, 0], model.support_vectors_[:, 1], s=100,
linewidth=1, facecolors='none', edgecolors='k')
plt.show()
plot_svm_boundary(model,X,y)
```
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154126120-213298888.png)
我們匯入`sklearn`下的`SVC`(Supprt Vector Classifier)分類器,它是SVM的一種實現。
## 2.4 SVC引數C
SVC方法引數`C`代表L2正則化引數,正則化的強度與`C`的值城**反比**,即C值越大正則化強度越弱,其必須嚴格為正。
```
model = SVC(kernel='linear', C=0.05)
model.fit(X, y)
plot_svm_boundary(model,X,y)
```
我們減少C的值,可以看到模型擬合數據的程度減弱。
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154133746-652684456.png)
## 2.5 核技巧
SVC方法的`kernel`引數可取值`{'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'}`。像前文中所使用的那樣,我們可以使`kernel='linear'`進行線性分類。那麼如果我們像進行非線性分類呢?
### 2.5.1 多項式核心
**多項式核心**`kernel='poly'`的原理簡單來說就是,**用單一特徵生成多特徵來擬合曲線**。比如我們拓展X到y的對應關係如下:
| | X | X^2 | X^3 | y|
|---:|-----------:|-----------:|-----------:|----------------:|
| 0 | 6.50823 | 6.50823**2 | 6.50823**3 | 0 |
| 1 | 4.12612 | 4.12612**2 | 4.12612**3 | 1 |
| 2 | 6.42787 | 6.42787**2 | 6.42787**3 | 0 |
| 3 | 3.67295 | 3.67295**2 | 3.67295**3 | 1 |
| 4 | 1.58032 | 1.58032**2 | 1.58032**3 | 1 |
這樣我們就可以用曲線來擬合數據集。
```
model = SVC(kernel='poly', C=0.05,degree=5)
model.fit(X, y)
plot_svm_boundary(model,X,y)
```
我們使用多項式核心,並通過`degree=5`設定多項式的**最高次數**為5。我們可以看出分割出現了一定的弧度。
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154140818-1612550262.png)
### 2.5.2 高斯RBF核心
SVC方法預設核心為高斯`RBF`,即Radial Basis Function(徑向基函式)。這時我們需要引入`gamma`引數來控制鐘形函式的形狀。增加gamma值會使鐘形曲線變得更窄,因此每個例項影響的範圍變小,決策邊界更不規則。減小gamma值會使鐘形曲線變得更寬,因此每個例項的影響範圍變大,決策邊界更平坦。
```
model = SVC(kernel='rbf', C=1,gamma=0.01)
model.fit(X, y)
plot_svm_boundary(model,X,y)
```
![](https://img2020.cnblogs.com/blog/678094/202102/678094-20210215154146946-552797300.png)
## 2.6 調參技巧:網格搜尋
```
from sklearn.model_selection import GridSearchCV
svm = SVC()
param_grid = {'C':[0.01,0.1,1],'kernel':['rbf','poly','linear','sigmoid'],'gamma':[0.01,0.1,1]}
grid = GridSearchCV(svm,param_grid)
grid.fit(X,y)
print("grid.best_params_ = ",grid.best_params_,", grid.best_score_ =" ,grid.best_score_)
```
我們可以通過`GridSearchCV`方法來遍歷超引數的各種可能性來尋求最優超引數。這是通過算力碾壓的方式暴力調參的手段。當然,在分析問題階段,我們必須限定了各引數的可選範圍才能應用此方法。
因為資料集太簡單,我們在遍歷第一種可能性時就已經得到100%的準確率了,輸出如下:
```
grid.best_params_ = {'C': 0.01, 'gamma': 0.01, 'kernel': 'rbf'} , grid.best_score_ = 1.0
```
# 總結
當我們處理線性可分的資料集時,可以使用`SVC(kernel='linear')`方法來訓練資料,當然我們也可以使用更快的方法`LinearSVC`來訓練資料,特別是當訓練集特別大或特徵非常多的時候。
當我們處理非線性SVM分類時,可以使用高斯RBF核心,多項式核心,sigmoid核心來進行非線性模型的的擬合。當然我們也可以通過GridSearchCV尋找最優引數。
往期文章:
* [機器學習(三):理解邏輯迴歸及二分類、多分類程式碼實踐](https://mp.weixin.qq.com/s/cEbGM0_Lrt8elfubxSF9jg)
* [機器學習(二):理解線性迴歸與梯度下降並做簡單預測](https://mp.weixin.qq.com/s/_bHi-XH5ZXI4jDUzwH3xpQ)
* [機器學習(一):5分鐘理解機器學習並上手實踐](https://mp.weixin.qq.com/s/-KsbtgOc3C3ry-8P5f8K-Q)
* [前置機器學習(五):30分鐘掌握常用Matplotlib用法](https://mp.weixin.qq.com/s/5brLPnUP6sYvc-_JO7IzkA)
* [前置機器學習(四):一文掌握Pandas用法](https://mp.weixin.qq.com/s/LlLkkBfI-4s3qdVaiv7EdQ)
* [前置機器學習(三):30分鐘掌握常用NumPy用法](https://mp.weixin.qq.com/s/U8dV8ENzzSx_VwBDdJdr_w)
* [前置機器學習(二):30分鐘掌握常用Jupyter Notebook用法](https://mp.weixin.qq.com/s/PCGThwI-YD7_hHxO35V8xw)
* [前置機器學習(一):數學符號及希臘字母](https://mp.weixin.qq.com/s/BLxyqK3CGV9yd92