1. 程式人生 > >支援向量機(SVM)實現MNIST手寫體數字識別

支援向量機(SVM)實現MNIST手寫體數字識別

一、SVM演算法簡述

支援向量機即Support Vector Machine,簡稱SVM。一聽這個名字,就有眩暈的感覺。支援(Support)、向量(Vector)、機器(Machine),這三個毫無關聯的詞,硬生生地湊在了一起。從修辭的角度,這個合成詞最終落腳到”Machine”上,還以為是一種牛X的機器呢?實際上,它是一種演算法,是效果最好的分類演算法之一。 
SVM是最大間隔分類器,它能很好地處理線性可分的問題,並可推廣到非線性問題。實際使用的時候,還需要考慮噪音的問題。 
本文只是一篇學習筆記,主要參考了July、pluskid等人相關文章。將要點記錄下來,促進自己的進步。

SVM是最大間隔分類器

既然SVM是用來分類的,咱就舉個簡單的例子,看看這個SVM有啥特點。如下圖所示,有一個二維平面,平面上有兩種不同的資料,分別用圈和叉表示。由於這些資料是線性可分的,可以用一條直線將這兩個資料分開,這樣的直線可以有無數條。 
大間隔分類器 
綠線、粉紅線、黑線都能將兩類區分開。但是那種更好呢?感覺上黑線似乎更好些。粉紅線和綠線都離樣本太近。要是樣本或分界線稍稍有些擾動,分類就可能出錯。黑線好就好在離兩類都有一個安全間隔(藍線與黑線間的間隔),即使有些擾動,分類還是準確的。這個安全間隔,也就是“Margin”,當然我們覺得間隔越大分類越準確。 
這種分類思想該作何理解呢,他和邏輯迴歸的分類有何區別呢? 
當用邏輯迴歸的思想來處理分類問題時(將資料分成正負兩類:正類y=1,負類y=0)。邏輯迴歸函式反映的是資料是正類的概率,當這個概率大於0.5時,預測這個資料是正類,反之,小於0.5時,預測這個資料是負類。它優化的目標是預測出錯的概率越小越好。

可以參看這裡 
SVM則不同,它要找出一條離兩類都有一定安全間隔的分界線(專業點叫超平面)。優化的目標就是安全間隔越大越好。 
因此,SVM也被叫做最大間隔分類器。

線性可分的情況

SVM是通過間隔來分類。我們怎麼來定量地表達呢?先來看看線性可分的情況,分類函式

f(x)=wTx+bf(x)=wTx+b


xx是特徵向量,ww是與特徵向量維數相同的向量,也叫權重向量,bb是一個實數,也叫偏置。當f(x)=0f(x)=0時,表達的就是SVM的分類邊界,也就是超平面。SVM分成的類y可以為1或-1(注意,與邏輯迴歸不同,不是1和0)。f(x)f(x)大於0的點對應y=1的資料,f(x)f(x)小於0的點對應y=-1的資料。那我們關心的間隔怎麼表達? 
先來看看函式間隔,用γ^γ^表示:γ^=y(wTx+b)=yf(x)γ^=y(wTx+b)=yf(x)。|f(x)||f(x)|值越大,也就是yf(x)yf(x)越大,資料點離超平面越遠,我們越能確信這個資料屬於哪一類別,這是最直觀的認識。 
那這個是不是就完美表達了我們想要的間隔呢?看看這種情況,固定超平面,當ww,bb同時乘以2,這個間隔就擴大了兩倍。那怎麼表達不受引數縮放的變化影響的間距呢?老老實實來畫個圖看看咯。 
這裡寫圖片描述

 
xx是超平面外的一點,它離超平面的距離是γγ,顯然ww是超平面的法向量,x0x0是xx在超平面的投影。則x=x0+γw||w||x=x0+γw||w||,其中||w||||w||是範數,用初等數學來理解就是向量的長度,也叫向量的模。因為在超平面上,f(x0)=0f(x0)=0,等式兩邊乘以wTwT,再加上一個bb,化簡可得γ=wTx+b||w||=f(x)||w||γ=wTx+b||w||=f(x)||w||。注意這個γγ是可正可負的,為了得到絕對值,乘以一個對應的類別y,即可得出幾何間隔(用γ~γ~表示)的定義: 

γ~=yf(x)∥w∥=γ^∥w∥γ~=yf(x)‖w‖=γ^‖w‖


這個γ~γ~是不受引數縮放影響的。於是,我們的SVM的目標函式就是

maxγ~maxγ~


,當然它得滿足一些條件,根據margin的含義

yi(wTxi+b)=γ^i≥γ^,i=1,…,nyi(wTxi+b)=γ^i≥γ^,i=1,…,n


其中γ^=γ~∥w∥γ^=γ~‖w‖.之前說過,即使超平面固定,γ~γ~的值也會隨著||w||||w||的變化而變化。由於我們的目標就是要確定超平面,因此可以將無關的變數固定下來,固定的方式有兩種:一是固定||w||||w||,當我們找到最優的γ~γ~時γ^γ^也就隨之而固定;二是反過來固定γ^γ^,此時||w||||w||也可以根據最優的γ~γ~得到。出於方便推導和優化的目的,我們選第二種,令γ^=1γ^=1,則我們的目標函式化為: 

max1∥w∥,s.t.,yi(wTxi+b)≥1,i=1,…,nmax1‖w‖,s.t.,yi(wTxi+b)≥1,i=1,…,n

 

支援向量作何理解

說了這麼多,也沒有說到Support vector(支援向量),仔細觀看下圖:支援向量 
有兩個支撐著中間的分界超平面的超平面,稱為gap。它們到分界超平面的距離相等。這兩個gap上必定會有一些資料點。如果沒有,我們就可以進一步擴大margin了,那就不是最大的margin了。這些經過gap的資料點,就是支援向量(Support Vector)(它們支援了中間的超平面)。很顯然,只有支援向量才決定超平面,其他的資料點不影響超平面的確定。 
這是一個十分優良的特性。假設有100萬個資料點,支援向量100個,我們實際上只需要用這100個支援向量進行計算!!!這將大大提高儲存和計算的效能。

線性SVM的求解

考慮目標函式:max1∥w∥,s.t.,yi(wTxi+b)≥1,i=1,…,nmax1‖w‖,s.t.,yi(wTxi+b)≥1,i=1,…,n 
由於求的1||w||1||w||最大值相當於求12∥w∥212‖w‖2的最小值,所以上述目標函式等價於: 

min12∥w∥2s.t.,yi(wTxi+b)≥1,i=1,…,nmin12‖w‖2s.t.,yi(wTxi+b)≥1,i=1,…,n


1/2是方便求導時約去。這時目標函式是二次的,約束條件是線性的,所以它是一個凸二次規劃問題。這個問題可以用現成的QP(Quadratic Programming)的優化包進行求解。但是這個問題還有些特殊的結構,可以通過Lagrange Duality變換到對偶變數的優化問題。通常求解對偶變數優化問題的方法比QP優化包高效得多,而且推導過程中,可以很方便地引出核函式。 
簡單地說,通過給每個約束條件加上一個拉格朗日乘子,我們可以將它們融和到目標函式裡去,拉格朗日函式如下: 

L(w,b,α)=12∥w∥2−∑i=1nαi(yi(wTxi+b)−1)L(w,b,α)=12‖w‖2−∑i=1nαi(yi(wTxi+b)−1)

 

這裡還需要說明一點。當xixi不是支援向量時,αi=0αi=0;當xixi是支援向量時,yi(wTxi+b)−1=0yi(wTxi+b)−1=0。這個其實很好理解,因為超平面由支援向量決定,非支援向量不會影響到引數w。 
這裡省略掉推導的過程,這個函式經過變換,並且滿足KKT條件。會得出如下結論: 

w=∑i=1nαiyixiw=∑i=1nαiyixi

 

∑i=1nαiy=0∑i=1nαiy=0

 

求解的問題可以變換為 

maxα∑i=1nαi−12∑i,j=1nαiαjyiyjxTixjs.t.αi≥0,i=1,…,nmaxα∑i=1nαi−12∑i,j=1nαiαjyiyjxiTxjs.t.αi≥0,i=1,…,n

 

上式可以通過SMO演算法求出拉格朗日乘子αα,進而求出ww,通過

b=−maxyi=−1wTxi+minyj=1wTxj2b=−maxyi=−1wTxi+minyj=1wTxj2


,求出b

 

處理非線性問題

通過上面的討論,我們表達了SVM的目標函式,並給出了求解的方法。於是SVM就講完了,可以休息了?細心的讀者一定發現,上面是線上性可分的前提下展開討論的。線性不可分的時候怎麼辦? 
那可不可以將非線性問題轉換成線性問題呢?先來看個例子。 
兩個圓圈,非線性情況 
二維平面上,這是一個典型的線性不可分的問題。但我們增加一些特徵,將資料點對映到高維空間,他就變成了線性可分的點集了。如下圖: 
高維線性可分

事實上,將任何線性不可分的點集對映到高維空間(甚至可以到無窮維空間),總能變成線性可分的情況。只不過維數越高,計算量越大。維數大到無窮的時候,就是一場災難了。 
現在我們還是從數學上梳理一下這個對映的過程。 
根據w=∑ni=1αiyixiw=∑i=1nαiyixi,分類函式可寫成: 

f(x)=(∑i=1nαiyixi)Tx+b=∑i=1nαiyixTix+b=∑i=1nαiyi〈xi,x〉+bf(x)=(∑i=1nαiyixi)Tx+b=∑i=1nαiyixiTx+b=∑i=1nαiyi〈xi,x〉+b

 

〈⋅〉〈·〉表示向量內積。這個形式的有趣之處在於,對新點x的預測,只需要計算它與訓練資料點的內積即可。因為所有非支援向量所對應的係數αα都是0,因此對於新點的內積計算實際上只要針對少量的“支援向量”而不是所有的訓練資料。 
經過對映,分類函式變成

f(x)=∑i=1nαiyi〈ϕ(xi),ϕ(x)〉+bf(x)=∑i=1nαiyi〈ϕ(xi),ϕ(x)〉+b


而αα可以通過求解如下問題得到: 

maxα∑i=1nαi−12∑i,j=1nαiαjyiyj〈ϕ(xi),ϕ(xj)〉s.t.αi≥0,i=1,…,nmaxα∑i=1nαi−12∑i,j=1nαiαjyiyj〈ϕ(xi),ϕ(xj)〉s.t.αi≥0,i=1,…,n


這樣,似乎是拿到非線性資料,就找一個適當的對映ϕϕ,把原來的資料對映到新空間中,再做線性SVM即可。不過這個適當的對映可不是好惹的。二維空間做對映,需要5個維度,三維空間做對映,需要19個維度,維度數目是爆炸性增長的。到了無窮維,根本無法計算。這個時候就需要核函數出馬了。 
觀察上式,對映只是一箇中間過程,我們實際需要的是計算內積。如果有一種方式可以在特徵空間中直接計算內積。就能很好地避免維數災難了,這樣直接計算的方法稱為核函式方法。 
核是一個函式κκ,對所有x1x1,x2x2,滿足

κ(x1,x2)=〈ϕ(x1),ϕ(x2)〉κ(x1,x2)=〈ϕ(x1),ϕ(x2)〉

,這裡ϕϕ是從xx到內積特徵空間FF的對映。

 

幾個常用的核函式

通常人們會從一些常用的核函式中選擇(根據問題和資料的不同,選擇不同的引數,實際上就是得到了不同的核函式),例如:

  • 高斯核κ(x1,x2)=exp(−|x1−x2|22σ2)κ(x1,x2)=exp⁡(−|x1−x2|22σ2),這個空間會將原始空間對映到無窮維空間。不過,如果σσ選得很大的話,高次特徵上的權重實際上衰減得非常快,所以實際上(數值上近似一下)相當於一個低維的空間;反過來,如果σσ選得很小的話,則可以將任意的資料對映為線性可分。當然,這不一定是好事,因為隨之而來的可能是非常嚴重的過擬合問題。不過,總的來說,通過調控引數,高斯核實際上具有相當的靈活性,也是使用最廣泛的核函式之一。下圖所示的例子便是把低維空間不可分資料通過高斯核函式對映到了高維空間:

高斯核函式

  • 多項式核

    κ(x1,x2)=(〈x1,x2〉+R)dκ(x1,x2)=(〈x1,x2〉+R)d

    ,這個核所對應的對映實際上是可以寫出來的,該空間的維度是 
    (m+dd)(m+dd),其中mm是原始空間的維度。
  • 線性核κ(x1,x2)=〈x1,x2〉κ(x1,x2)=〈x1,x2〉,這實際上就是原始空間中的內積。這個核存在的主要目的是使得“對映後空間中的問題”和“對映前空間中的問題”兩者在形式上統一起來了(意思是說,咱們有的時候,寫程式碼,或寫公式的時候,只要寫個模板或通用表示式,然後再代入不同的核,便可以了,於此,便在形式上統一了起來,不用再分別寫一個線性的,和一個非線性的)。

核函式的本質

總結一下核函式,實際是三點:

  • 實際中,當我們遇到線性不可分的樣例,常用做法是把樣例特徵對映到高維空間中
  • 但如果凡是遇到線性不可分的樣例,一律對映到高維空間,那麼這個維度大小是會高到可怕的
  • 此時,核函式就隆重登場了,核函式的價值在於它雖然也是將特徵進行從低維到高維的轉換,但核函式絕就絕在它事先在低維上進行計算,而將實質上的分類效果表現在了高維上,也就如上文所說的避免了直接在高維空間中的複雜計算。

處理噪音

回顧此前的介紹,SVM用來處理線性可分的問題。後來為了處理非線性資料,使用核函式將原始資料對映到高維空間,轉化為線性可分的問題。但是有時候,並不是資料本身是非線性結構的,而只是因為資料有噪音。對於這種偏離正常位置很遠的資料點,我們稱之為outlier。超平面本身就是隻有少數幾個支援向量組成,如果支援向量裡存在outlier,就會有嚴重影響。如下圖: 
噪音 
用黑圈圈起來的那個藍點就是一個outlier,它偏離了自己原本所應該的那個半空間,如果直接忽略掉,原本的分隔超平面還是挺好的,但是由於這個outlier的出現,導致分隔超平面不得不被擠歪了,變成黑色虛線所示,同時margin也相應變小了。更嚴重的是,如果outlier再往右上移動一些距離的話,將無法構造出能將資料分開的超平面來。 
為了處理這種情況,SVM允許資料點在一定程度上偏離一下超平面。上圖中,黑色實線所對應的距離,就是該outlier偏離的距離,如果把它移動回來,就剛好落在原來的超平面上,而不會使超平面發生變形了。具體來說,原來的約束條件變成: 

yi(wTxi+b)≥1−ξi,i=1,…,nyi(wTxi+b)≥1−ξi,i=1,…,n


其中ξi≥0ξi≥0稱為鬆弛變數,對應資料點xixi允許偏離的函式間隔的量。對於一般的資料(非支援向量,也非outlier),這個值就是0。如果ξiξi任意大的話,那任意的超平面都是符合要求的。所以,我們在原來的目標函式後面加上一項,使得這些ξiξi的總和也要最小:

min12||w||2+C∑i=1nξimin12||w||2+C∑i=1nξi


,其中C是一個引數,用於控制目標函式中兩項(尋找margin最大的超平面和保證資料點偏差最小)之間的權重。注意,ξiξi是需要優化的變數,而CC是一個事先確定好的常量。完整的目標函式是: 

min12||w||2+C∑i=1nξis.t.yi(wTxi+b)≥1−ξi,i=1,…,nmin12||w||2+C∑i=1nξis.t.yi(wTxi+b)≥1−ξi,i=1,…,n


通過拉格朗日對偶求解, 

w=∑i=1nαiyixiw=∑i=1nαiyixi

 

∑i=1nαiy=0∑i=1nαiy=0


求解的問題可以變換為 

maxα∑i=1nαi−12∑i,j=1nαiαjyiyjxTixjs.t.0≤αi≤C,i=1,…,nmaxα∑i=1nαi−12∑i,j=1nαiαjyiyjxiTxjs.t.0≤αi≤C,i=1,…,n


對比之前的結果,只不過是αα多了一個上限CC。

 

小結

  • SVM是一個最大間距分類器。
  • 線上性可分的情況下,它的目標函式是min12|w|2s.t.,yi(wTxi+b)≥1,i=1,…,nmin12|w|2s.t.,yi(wTxi+b)≥1,i=1,…,n,較好的求解方法是轉換為拉格朗日對偶問題,並用SMO演算法進行求解
  • 線上性不可分的情況下,其基本思想是,將低維線性不可分的問題對映為高維可分的問題。具體實現辦法是:利用核函式,在低維空間進行運算,而將實質上的分類效果表現在高維上。
  • 考慮到資料點中可能存在噪音的干擾,需要將目標函式中加入鬆弛變數而求解的思路和方法不變。

二、程式碼及結果

環境是python3.x+sklearn+pythcarm

# -*- coding: utf-8 -*-
# @Time    : 2018/8/23 10:38
# @Author  : Barry
# @File    : mnist_svm.py
# @Software: PyCharm Community Edition

import pickle
import gzip

# Third-party libraries
import numpy as np

def load_data():
    """
    返回包含訓練資料、驗證資料、測試資料的元組的模式識別資料
    訓練資料包含50,000張圖片,測試資料和驗證資料都只包含10,000張圖片
    """
    f = gzip.open('./MNIST_data/mnist.pkl.gz', 'rb')
    training_data, validation_data, test_data = pickle.load(f,encoding='bytes')
    f.close()
    return (training_data, validation_data, test_data)


# Third-party libraries
from sklearn import svm
import time

def svm_baseline():
    print (time.strftime('%Y-%m-%d %H:%M:%S') )
    training_data, validation_data, test_data = load_data()
    # 傳遞訓練模型的引數,這裡用預設的引數
    clf = svm.SVC(C=100.0, kernel='rbf', gamma=0.03)
    # clf = svm.SVC(C=8.0, kernel='rbf', gamma=0.00,cache_size=8000,probability=False)
    # 進行模型訓練
    clf.fit(training_data[0], training_data[1])
    # test
    # 測試集測試預測結果
    predictions = [int(a) for a in clf.predict(test_data[0])]
    num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1]))
    print ("%s of %s test values correct." % (num_correct, len(test_data[1])))
    print (time.strftime('%Y-%m-%d %H:%M:%S'))

if __name__ == "__main__":
    svm_baseline()

執行結果:

2018-08-23 13:33:32
9848 of 10000 test values correct.
2018-08-23 13:43:20

準確率大約98.48%

三、SVM識別MNIST演算法過程

SVM分類演算法以另一個角度來考慮問題。其思路是獲取大量的手寫數字,常稱作訓練樣本,然後開發出一個可以從這些訓練樣本中進行學習的系統。換言之,SVM使用樣本來自動推斷出識別手寫數字的規則。隨著樣本數量的增加,演算法可以學到更多關於手寫數字的知識,這樣就能夠提升自身的準確性。 
本文采用的資料集就是著名的“MNIST資料集”。這個資料集有60000個訓練樣本資料集和10000個測試用例。直接呼叫scikit-learn庫中的SVM,使用預設的引數,1000張手寫數字圖片,判斷準確的圖片就高達9435張。

通常,對於分類問題。我們會將資料集分成三部分,訓練集、測試集、交叉驗證集。用訓練集訓練生成模型,用測試集和交叉驗證集進行驗證模型的準確性。

需要說明的是,svm.SVC()函式的幾個重要引數。直接用help命令檢視一下文件,這裡我稍微翻譯了一下: 
C : 浮點型,可選 (預設=1.0)。誤差項的懲罰引數C 
kernel : 字元型, 可選 (預設=’rbf’)。指定核函式型別。只能是’linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ 或者自定義的。如果沒有指定,預設使用’rbf’。如果使用自定義的核函式,需要預先計算核矩陣。 
degree : 整形, 可選 (預設=3)。用多項式核函式(‘poly’)時,多項式核函式的引數d,用其他核函式,這個引數可忽略 
gamma : 浮點型, 可選 (預設=0.0)。’rbf’, ‘poly’ and ‘sigmoid’核函式的係數。如果gamma是0,實際將使用特徵維度的倒數值進行運算。也就是說,如果特徵是100個維度,實際的gamma是1/100。 
coef0 : 浮點型, 可選 (預設=0.0)。核函式的獨立項,’poly’ 和’sigmoid’核時才有意義。 
可以適當調整一下SVM分類演算法,看看不同引數的結果。當我的引數選擇為C=100.0, kernel=’rbf’, gamma=0.03時,預測的準確度就已經高達98.5%了。

相同的C,gamma越大,分類邊界離樣本越近。相同的gamma,C越大,分類越嚴格。 
下圖是不同C和gamma下分類器交叉驗證準確率的熱力圖 
gamma和C 
由圖可知,模型對gamma引數是很敏感的。如果gamma太大,無論C取多大都不能阻止過擬合。當gamma很小,分類邊界很像線性的。取中間值時,好的模型的gamma和C大致分佈在對角線位置。還應該注意到,當gamma取中間值時,C取值可以是很大的。 
在實際專案中,這幾個引數按一定的步長,多試幾次,一般就能得到比較好的分類效果了。

小結

回顧一下整個問題。我們進行了如下操作。對資料集分成了三部分,訓練集、測試集和交叉驗證集。用SVM分類模型進行訓練,依據測試集和驗證集的預測結果來優化引數。依靠sklearn這個強大的機器學習庫,我們也能解決手寫識別這麼高大上的問題了。事實上,我們只用了幾行簡單程式碼,就讓測試集的預測準確率高達98.5%。 
事實上,就算是一般性的機器學習問題,我們也是有一些一般性的思路的,如下: 
這裡寫圖片描述