1. 程式人生 > >支援向量機SVM通俗理解(python程式碼實現)

支援向量機SVM通俗理解(python程式碼實現)

這是第三次來“複習”SVM了,第一次是使用SVM包,呼叫包並嘗試調節引數。聽聞了“流弊”SVM的演算法。第二次學習理論,看了李航的《統計學習方法》以及網上的部落格。看完後感覺,滿滿的公式。。。記不住啊。第三次,也就是這次通過python程式碼手動來實現SVM,才讓我突然對SVM不有畏懼感。希望這裡我能通過簡單粗暴的文字,能讓讀者理解到底什麼是SVM,這貨的演算法思想是怎麼樣的。看之前千萬不要畏懼,說到底就是個演算法,每天啃一點,總能啃完它,慢慢來還可以加深印象。
SVM是用來解決分類問題的,如果解決兩個變數的分類問題,可以理解成用一條直線把點給分開,完成分類。如下:
這裡寫圖片描述
上面這些點很明顯不一樣,我們從中間畫一條直線就可以用來分割這些點,但是什麼樣的直線才是最好的呢?通俗的說,就是一條直線“最能”分割這些點,也就是上圖中的直線。他是最好的一條直線,使所有的點都“儘量”遠離中間那條直線。總得的來說,SVM就是為了找出一條分割的效果最好的直線。怎麼樣找出這條直線,就變成了一個數學問題,通過數學一步一步的推導,最後轉化成程式。這裡舉例是二個特徵的分類問題,如果有三個特徵,分類線就變成了分類平面,多個特徵的話就變成了超平面。從這個角度出發去看待SVM,會比較輕鬆。

數學解決方法大致如下:
目的是求最大分隔平面,也就是選取靠近平面最近的點,使這些點到分隔平面的距離W最大,是一個典型的凸二次規劃問題。
這裡寫圖片描述
但是上面需要求解兩個引數w和b;於是為求解這個問題,把二次規劃問題轉換為對偶問題
這裡寫圖片描述
這樣就只需求一個引數a了,通過SMO演算法求出a後,再計算出b
這裡寫圖片描述
最後通過f(x)用做預測。

python程式碼實現可以加深對那些數學推導公式的印象,看公式的時候,可能會想,這些推導好複雜,都有些什麼用啊,結果寫程式碼的時候會發現,原來最後都用在程式碼裡。所以寫程式碼可以加深對SVM的理解。
下面是SVM的python程式碼實現,我做了詳細的註釋,剛開始看程式碼也會覺得好長好複雜,慢慢看後發現,程式碼就是照著SVM的數學推導,把最後的公式推導轉化為程式碼和程式的邏輯,程式碼本身並不複雜。

from numpy import * 

def loadDataSet(filename): #讀取資料
    dataMat=[]
    labelMat=[]
    fr=open(filename)
    for line in fr.readlines():
        lineArr=line.strip().split('\t')
        dataMat.append([float(lineArr[0]),float(lineArr[1])])
        labelMat.append(float(lineArr[2]))
    return dataMat,labelMat #返回資料特徵和資料類別
def selectJrand(i,m): #在0-m中隨機選擇一個不是i的整數 j=i while (j==i): j=int(random.uniform(0,m)) return j def clipAlpha(aj,H,L): #保證a在L和H範圍內(L <= a <= H) if aj>H: aj=H if L>aj: aj=L return aj def kernelTrans(X, A, kTup): #核函式,輸入引數,X:支援向量的特徵樹;A:某一行特徵資料;kTup:('lin',k1)核函式的型別和引數 m,n = shape(X) K = mat(zeros((m,1))) if kTup[0]=='lin': #線性函式 K = X * A.T elif kTup[0]=='rbf': # 徑向基函式(radial bias function) for j in range(m): deltaRow = X[j,:] - A K[j] = deltaRow*deltaRow.T K = exp(K/(-1*kTup[1]**2)) #返回生成的結果 else: raise NameError('Houston We Have a Problem -- That Kernel is not recognized') return K #定義類,方便儲存資料 class optStruct: def __init__(self,dataMatIn, classLabels, C, toler, kTup): # 儲存各類引數 self.X = dataMatIn #資料特徵 self.labelMat = classLabels #資料類別 self.C = C #軟間隔引數C,引數越大,非線性擬合能力越強 self.tol = toler #停止閥值 self.m = shape(dataMatIn)[0] #資料行數 self.alphas = mat(zeros((self.m,1))) self.b = 0 #初始設為0 self.eCache = mat(zeros((self.m,2))) #快取 self.K = mat(zeros((self.m,self.m))) #核函式的計算結果 for i in range(self.m): self.K[:,i] = kernelTrans(self.X, self.X[i,:], kTup) def calcEk(oS, k): #計算Ek(參考《統計學習方法》p127公式7.105) fXk = float(multiply(oS.alphas,oS.labelMat).T*oS.K[:,k] + oS.b) Ek = fXk - float(oS.labelMat[k]) return Ek #隨機選取aj,並返回其E值 def selectJ(i, oS, Ei): maxK = -1 maxDeltaE = 0 Ej = 0 oS.eCache[i] = [1,Ei] validEcacheList = nonzero(oS.eCache[:,0].A)[0] #返回矩陣中的非零位置的行數 if (len(validEcacheList)) > 1: for k in validEcacheList: if k == i: continue Ek = calcEk(oS, k) deltaE = abs(Ei - Ek) if (deltaE > maxDeltaE): #返回步長最大的aj maxK = k maxDeltaE = deltaE Ej = Ek return maxK, Ej else: j = selectJrand(i, oS.m) Ej = calcEk(oS, j) return j, Ej def updateEk(oS, k): #更新os資料 Ek = calcEk(oS, k) oS.eCache[k] = [1,Ek] #首先檢驗ai是否滿足KKT條件,如果不滿足,隨機選擇aj進行優化,更新ai,aj,b值 def innerL(i, oS): #輸入引數i和所有引數資料 Ei = calcEk(oS, i) #計算E值 if ((oS.labelMat[i]*Ei < -oS.tol) and (oS.alphas[i] < oS.C)) or ((oS.labelMat[i]*Ei > oS.tol) and (oS.alphas[i] > 0)): #檢驗這行資料是否符合KKT條件 參考《統計學習方法》p128公式7.111-113 j,Ej = selectJ(i, oS, Ei) #隨機選取aj,並返回其E值 alphaIold = oS.alphas[i].copy() alphaJold = oS.alphas[j].copy() if (oS.labelMat[i] != oS.labelMat[j]): #以下程式碼的公式參考《統計學習方法》p126 L = max(0, oS.alphas[j] - oS.alphas[i]) H = min(oS.C, oS.C + oS.alphas[j] - oS.alphas[i]) else: L = max(0, oS.alphas[j] + oS.alphas[i] - oS.C) H = min(oS.C, oS.alphas[j] + oS.alphas[i]) if L==H: print("L==H") return 0 eta = 2.0 * oS.K[i,j] - oS.K[i,i] - oS.K[j,j] #參考《統計學習方法》p127公式7.107 if eta >= 0: print("eta>=0") return 0 oS.alphas[j] -= oS.labelMat[j]*(Ei - Ej)/eta #參考《統計學習方法》p127公式7.106 oS.alphas[j] = clipAlpha(oS.alphas[j],H,L) #參考《統計學習方法》p127公式7.108 updateEk(oS, j) if (abs(oS.alphas[j] - alphaJold) < oS.tol): #alpha變化大小閥值(自己設定) print("j not moving enough") return 0 oS.alphas[i] += oS.labelMat[j]*oS.labelMat[i]*(alphaJold - oS.alphas[j])#參考《統計學習方法》p127公式7.109 updateEk(oS, i) #更新資料 #以下求解b的過程,參考《統計學習方法》p129公式7.114-7.116 b1 = oS.b - Ei- oS.labelMat[i]*(oS.alphas[i]-alphaIold)*oS.K[i,i] - oS.labelMat[j]*(oS.alphas[j]-alphaJold)*oS.K[i,j] b2 = oS.b - Ej- oS.labelMat[i]*(oS.alphas[i]-alphaIold)*oS.K[i,j]- oS.labelMat[j]*(oS.alphas[j]-alphaJold)*oS.K[j,j] if (0 < oS.alphas[i]<oS.C): oS.b = b1 elif (0 < oS.alphas[j]<oS.C): oS.b = b2 else: oS.b = (b1 + b2)/2.0 return 1 else: return 0 #SMO函式,用於快速求解出alpha def smoP(dataMatIn, classLabels, C, toler, maxIter,kTup=('lin', 0)): #輸入引數:資料特徵,資料類別,引數C,閥值toler,最大迭代次數,核函式(預設線性核) oS = optStruct(mat(dataMatIn),mat(classLabels).transpose(),C,toler, kTup) iter = 0 entireSet = True alphaPairsChanged = 0 while (iter < maxIter) and ((alphaPairsChanged > 0) or (entireSet)): alphaPairsChanged = 0 if entireSet: for i in range(oS.m): #遍歷所有資料 alphaPairsChanged += innerL(i,oS) print("fullSet, iter: %d i:%d, pairs changed %d" % (iter,i,alphaPairsChanged)) #顯示第多少次迭代,那行特徵資料使alpha發生了改變,這次改變了多少次alpha iter += 1 else: nonBoundIs = nonzero((oS.alphas.A > 0) * (oS.alphas.A < C))[0] for i in nonBoundIs: #遍歷非邊界的資料 alphaPairsChanged += innerL(i,oS) print("non-bound, iter: %d i:%d, pairs changed %d" % (iter,i,alphaPairsChanged)) iter += 1 if entireSet: entireSet = False elif (alphaPairsChanged == 0): entireSet = True print("iteration number: %d" % iter) return oS.b,oS.alphas def testRbf(data_train,data_test): dataArr,labelArr = loadDataSet(data_train) #讀取訓練資料 b,alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, ('rbf', 1.3)) #通過SMO演算法得到b和alpha datMat=mat(dataArr) labelMat = mat(labelArr).transpose() svInd=nonzero(alphas)[0] #選取不為0資料的行數(也就是支援向量) sVs=datMat[svInd] #支援向量的特徵資料 labelSV = labelMat[svInd] #支援向量的類別(1或-1) print("there are %d Support Vectors" % shape(sVs)[0]) #打印出共有多少的支援向量 m,n = shape(datMat) #訓練資料的行列數 errorCount = 0 for i in range(m): kernelEval = kernelTrans(sVs,datMat[i,:],('rbf', 1.3)) #將支援向量轉化為核函式 predict=kernelEval.T * multiply(labelSV,alphas[svInd]) + b #這一行的預測結果(程式碼來源於《統計學習方法》p133裡面最後用於預測的公式)注意最後確定的分離平面只有那些支援向量決定。 if sign(predict)!=sign(labelArr[i]): #sign函式 -1 if x < 0, 0 if x==0, 1 if x > 0 errorCount += 1 print("the training error rate is: %f" % (float(errorCount)/m)) #打印出錯誤率 dataArr_test,labelArr_test = loadDataSet(data_test) #讀取測試資料 errorCount_test = 0 datMat_test=mat(dataArr_test) labelMat = mat(labelArr_test).transpose() m,n = shape(datMat_test) for i in range(m): #在測試資料上檢驗錯誤率 kernelEval = kernelTrans(sVs,datMat_test[i,:],('rbf', 1.3)) predict=kernelEval.T * multiply(labelSV,alphas[svInd]) + b if sign(predict)!=sign(labelArr_test[i]): errorCount_test += 1 print("the test error rate is: %f" % (float(errorCount_test)/m)) #主程式 def main(): filename_traindata='C:\\Users\\Administrator\\Desktop\\data\\traindata.txt' filename_testdata='C:\\Users\\Administrator\\Desktop\\data\\testdata.txt' testRbf(filename_traindata,filename_testdata) if __name__=='__main__': main()

樣例資料如下:
這裡寫圖片描述
訓練資料:train_data

-0.214824   0.662756    -1.000000
-0.061569   -0.091875   1.000000
0.406933    0.648055    -1.000000
0.223650    0.130142    1.000000
0.231317    0.766906    -1.000000
-0.748800   -0.531637   -1.000000
-0.557789   0.375797    -1.000000
0.207123    -0.019463   1.000000
0.286462    0.719470    -1.000000
0.195300    -0.179039   1.000000
-0.152696   -0.153030   1.000000
0.384471    0.653336    -1.000000
-0.117280   -0.153217   1.000000
-0.238076   0.000583    1.000000
-0.413576   0.145681    1.000000
0.490767    -0.680029   -1.000000
0.199894    -0.199381   1.000000
-0.356048   0.537960    -1.000000
-0.392868   -0.125261   1.000000
0.353588    -0.070617   1.000000
0.020984    0.925720    -1.000000
-0.475167   -0.346247   -1.000000
0.074952    0.042783    1.000000
0.394164    -0.058217   1.000000
0.663418    0.436525    -1.000000
0.402158    0.577744    -1.000000
-0.449349   -0.038074   1.000000
0.619080    -0.088188   -1.000000
0.268066    -0.071621   1.000000
-0.015165   0.359326    1.000000
0.539368    -0.374972   -1.000000
-0.319153   0.629673    -1.000000
0.694424    0.641180    -1.000000
0.079522    0.193198    1.000000
0.253289    -0.285861   1.000000
-0.035558   -0.010086   1.000000
-0.403483   0.474466    -1.000000
-0.034312   0.995685    -1.000000
-0.590657   0.438051    -1.000000
-0.098871   -0.023953   1.000000
-0.250001   0.141621    1.000000
-0.012998   0.525985    -1.000000
0.153738    0.491531    -1.000000
0.388215    -0.656567   -1.000000
0.049008    0.013499    1.000000
0.068286    0.392741    1.000000
0.747800    -0.066630   -1.000000
0.004621    -0.042932   1.000000
-0.701600   0.190983    -1.000000
0.055413    -0.024380   1.000000
0.035398    -0.333682   1.000000
0.211795    0.024689    1.000000
-0.045677   0.172907    1.000000
0.595222    0.209570    -1.000000
0.229465    0.250409    1.000000
-0.089293   0.068198    1.000000
0.384300    -0.176570   1.000000
0.834912    -0.110321   -1.000000
-0.307768   0.503038    -1.000000
-0.777063   -0.348066   -1.000000
0.017390    0.152441    1.000000
-0.293382   -0.139778   1.000000
-0.203272   0.286855    1.000000
0.957812    -0.152444   -1.000000
0.004609    -0.070617   1.000000
-0.755431   0.096711    -1.000000
-0.526487   0.547282    -1.000000
-0.246873   0.833713    -1.000000
0.185639    -0.066162   1.000000
0.851934    0.456603    -1.000000
-0.827912   0.117122    -1.000000
0.233512    -0.106274   1.000000
0.583671    -0.709033   -1.000000
-0.487023   0.625140    -1.000000
-0.448939   0.176725    1.000000
0.155907    -0.166371   1.000000
0.334204    0.381237    -1.000000
0.081536    -0.106212   1.000000
0.227222    0.527437    -1.000000
0.759290    0.330720    -1.000000
0.204177    -0.023516   1.000000
0.577939    0.403784    -1.000000
-0.568534   0.442948    -1.000000
-0.011520   0.021165    1.000000
0.875720    0.422476    -1.000000
0.297885    -0.632874   -1.000000
-0.015821   0.031226    1.000000
0.541359    -0.205969   -1.000000
-0.689946   -0.508674   -1.000000
-0.343049   0.841653    -1.000000
0.523902    -0.436156   -1.000000
0.249281    -0.711840   -1.000000
0.193449    0.574598    -1.000000
-0.257542   -0.753885   -1.000000
-0.021605   0.158080    1.000000
0.601559    -0.727041   -1.000000
-0.791603   0.095651    -1.000000
-0.908298   -0.053376   -1.000000
0.122020    0.850966    -1.000000
-0.725568   -0.292022   -1.000000

測試資料:test_data

0.676771    -0.486687   -1.000000
0.008473    0.186070    1.000000
-0.727789   0.594062    -1.000000
0.112367    0.287852    1.000000
0.383633    -0.038068   1.000000
-0.927138   -0.032633   -1.000000
-0.842803   -0.423115   -1.000000
-0.003677   -0.367338   1.000000
0.443211    -0.698469   -1.000000
-0.473835   0.005233    1.000000
0.616741    0.590841    -1.000000
0.557463    -0.373461   -1.000000
-0.498535   -0.223231   -1.000000
-0.246744   0.276413    1.000000
-0.761980   -0.244188   -1.000000
0.641594    -0.479861   -1.000000
-0.659140   0.529830    -1.000000
-0.054873   -0.238900   1.000000
-0.089644   -0.244683   1.000000
-0.431576   -0.481538   -1.000000
-0.099535   0.728679    -1.000000
-0.188428   0.156443    1.000000
0.267051    0.318101    1.000000
0.222114    -0.528887   -1.000000
0.030369    0.113317    1.000000
0.392321    0.026089    1.000000
0.298871    -0.915427   -1.000000
-0.034581   -0.133887   1.000000
0.405956    0.206980    1.000000
0.144902    -0.605762   -1.000000
0.274362    -0.401338   1.000000
0.397998    -0.780144   -1.000000
0.037863    0.155137    1.000000
-0.010363   -0.004170   1.000000
0.506519    0.486619    -1.000000
0.000082    -0.020625   1.000000
0.057761    -0.155140   1.000000
0.027748    -0.553763   -1.000000
-0.413363   -0.746830   -1.000000
0.081500    -0.014264   1.000000
0.047137    -0.491271   1.000000
-0.267459   0.024770    1.000000
-0.148288   -0.532471   -1.000000
-0.225559   -0.201622   1.000000
0.772360    -0.518986   -1.000000
-0.440670   0.688739    -1.000000
0.329064    -0.095349   1.000000
0.970170    -0.010671   -1.000000
-0.689447   -0.318722   -1.000000
-0.465493   -0.227468   -1.000000
-0.049370   0.405711    1.000000
-0.166117   0.274807    1.000000
0.054483    0.012643    1.000000
0.021389    0.076125    1.000000
-0.104404   -0.914042   -1.000000
0.294487    0.440886    -1.000000
0.107915    -0.493703   -1.000000
0.076311    0.438860    1.000000
0.370593    -0.728737   -1.000000
0.409890    0.306851    -1.000000
0.285445    0.474399    -1.000000
-0.870134   -0.161685   -1.000000
-0.654144   -0.675129   -1.000000
0.285278    -0.767310   -1.000000
0.049548    -0.000907   1.000000
0.030014    -0.093265   1.000000
-0.128859   0.278865    1.000000
0.307463    0.085667    1.000000
0.023440    0.298638    1.000000
0.053920    0.235344    1.000000
0.059675    0.533339    -1.000000
0.817125    0.016536    -1.000000
-0.108771   0.477254    1.000000
-0.118106   0.017284    1.000000
0.288339    0.195457    1.000000
0.567309    -0.200203   -1.000000
-0.202446   0.409387    1.000000
-0.330769   -0.240797   1.000000
-0.422377   0.480683    -1.000000
-0.295269   0.326017    1.000000
0.261132    0.046478    1.000000
-0.492244   -0.319998   -1.000000
-0.384419   0.099170    1.000000
0.101882    -0.781145   -1.000000
0.234592    -0.383446   1.000000
-0.020478   -0.901833   -1.000000
0.328449    0.186633    1.000000
-0.150059   -0.409158   1.000000
-0.155876   -0.843413   -1.000000
-0.098134   -0.136786   1.000000
0.110575    -0.197205   1.000000
0.219021    0.054347    1.000000
0.030152    0.251682    1.000000
0.033447    -0.122824   1.000000
-0.686225   -0.020779   -1.000000
-0.911211   -0.262011   -1.000000
0.572557    0.377526    -1.000000
-0.073647   -0.519163   -1.000000
-0.281830   -0.797236   -1.000000
-0.555263   0.126232    -1.000000

參考:
《統計學習方法》
《Machine Learning in Action》