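Implementing the SMO Algorithm for Support Vector Machines (Source Code Explained Line by Line)
The support vector machine (SVM) is often called the best algorithm in machine learning: its training problem has an optimal solution, and most practical problems can be solved. The algorithm needs a lot of storage and computation, however, so it is poorly suited to large data sets; the SMO algorithm invented by Platt simplifies the computation and improves efficiency considerably. Below is my MATLAB source code for the two-class SVM problem. I have benefited a great deal from the guidance of others online, so I am sharing it here rather than keeping it to myself. Anyone interested in machine learning, neural networks, pattern recognition, or time-series prediction is welcome to discuss it and help improve it. Thanks in advance.
For the theory and principles of SVMs, see the link below, with thanks to the original authors, July and pluskid: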
http://blog.csdn.net/v_july_v/article/details/7624837
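Shown here is the classification result for the double-moon two-class problem, with 300 training points and 2000 validation points. The red curve is the decision boundary and the red * points are the support vectors. 5 validation points are misclassified, an error rate of 0.25%; given that the two classes are not separable in this configuration, that error is acceptable.
Now to the point. The source code follows: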
%SMO algorithm for solving the SVM optimization problem;
%build the inner core first, since the optimization step is the
%easiest part,
%then develop the selection method;
%-------------------20140204----------------------------------------------
%revised the stopping test: check whether the alpha parameters and related
%points meet the KKT conditions. No obvious improvement.
%17:16 try another method: C-SVM;
%at first, C=100;
%and at the same time, I need efxt;
%-------------------------------------------------------------------------
%{
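The notes above outline the problem. My approach is to start from the computational core (the kernel function) and build the whole algorithm up from there. Because the computation calls the kernel function many times, simplifying it and avoiding repeated evaluations is the key to efficiency and to optimizing the code.
Before that, of course, I drew a flowchart of the whole program, so that I know what each step means, where the code currently stands while I am writing it, and where it goes next.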
%}
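%{
One possible shortcut for the kernel reuse described above, offered as a minimal sketch rather than the method this script uses: with only N=300 training points, the whole Gaussian Gram matrix could be precomputed once in vectorized form, so that no "is this entry still 0?" cache test is needed. The function name gauss_gram and its argument s2 (playing the role of si, the kernel variance) are hypothetical and do not appear in the script.

```matlab
% Hypothetical helper (not part of the original script): full Gaussian Gram
% matrix for the columns of X; s2 is the kernel variance (si in the script).
function G = gauss_gram(X, s2)
    sq = sum(X.^2, 1);                        % 1-by-N squared norms of the columns
    D2 = bsxfun(@plus, sq', sq) - 2*(X'*X);   % N-by-N pairwise squared distances
    D2 = max(D2, 0);                          % clip tiny negatives caused by round-off
    G  = exp(-D2 ./ (2*s2));                  % same formula as K(x,z,sigma)
end
```

Under these assumptions, G = gauss_gram(X, si) computed right after X is built would give K(X(:,i),X(:,j),si) as G(i,j), at the cost of storing N*N doubles.
%}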
clear all;
close all;
%close all figures and clear all variables from memory;
st = cputime; %record current time, used to measure the total run time;
%-------------------------------------------------
% set the initial value;
%-------------------------------------------------
N=300; %the training data set size; 300 training points;
NN=2000; %the validation data set size; 2000 validation points;
NS=N; %number of SVs; initially every data point is assumed to be a support vector;
%------------------------------------------------
%produce the training set
%------------------------------------------------
fprintf('Producing the training set...\n');
fprintf('-----------------------------------------\n');
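%progress messages like these are embedded in the code so we can tell which step the program has reached;
r=10; %radius=10;
wth=6; %width=6;
dis=-6.5; %distance=-6.5;
si=0.085;
%build the double-moon shape: set the radius, width and distance, plus the variance of the kernel function (si). The variance is not settled in one shot; after a first run it can be estimated from the maximum distance between support vectors and the number of support vectors;
for i=1:N+NN; %at the beginning of every round, produce the data set at random; in Yanbo Xue's program, 3000 data points are generated in one go,
%and 1000 of those 3000 points are then used for repeated training over 50 epochs; try his approach first;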
%make the rand vectors X;
if rand>0.5; %make sure the points occur in random. and half in half for each class;
theta=pi*rand; %set the angle;
lenth=r+wth*(rand-0.5); %the range;
h(i)=lenth*cos(theta); %the X axis;
ve(i)=lenth*sin(theta); %the Y axis;
y(i)=-1; %the expected response; class 1;
else
beta=-pi*rand;
len_2=r+wth*(rand-0.5);
h(i)=len_2*cos(beta)+r;
ve(i)=len_2*sin(beta)-dis;
y(i)=1; %class 2;
end
end
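%this is the code that generates the double-moon shape; note that N+NN points are generated, i.e. the training and validation sets come from the same model;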
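%{
The generation loop above can also be written without a loop. The following is a sketch under the same geometry (radius r, width wth, vertical offset dis); the function name double_moon is hypothetical, and its output will not match the loop point for point because the random numbers are drawn in a different order.

```matlab
% Hypothetical vectorized double-moon generator (not part of the original script).
function [P, lab] = double_moon(n, r, wth, dis)
    top = rand(1, n) > 0.5;               % true -> upper moon, label -1
    ang = pi * rand(1, n);                % angle in [0, pi]
    rad = r + wth * (rand(1, n) - 0.5);   % radius in [r-wth/2, r+wth/2]
    px  = rad .* cos(ang);
    py  = rad .* sin(ang);
    px(~top) = px(~top) + r;              % lower moon: shift right by r
    py(~top) = -py(~top) - dis;           % lower moon: mirror, then shift by -dis
    P   = [px; py];                       % 2-by-n coordinates
    lab = ones(1, n);                     % lower moon, label +1
    lab(top) = -1;
end
```

With the script's settings, a call such as [P, lab] = double_moon(N+NN, r, wth, dis) would play the role of h, ve and y before normalization.
%}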
fprintf('Normalization...\n');
fprintf('-----------------------------------------\n');
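%the input data needs preprocessing: remove the mean and normalize;
% and use the random data.
%first remove the mean;
miu_x=mean(h); %mean of x1;
miu_y=mean(ve); %mean of x2;
h=h-miu_x*ones(1,NN+N); %subtract the mean; equivalent to removing the DC component;
ve=ve-miu_y*ones(1,NN+N); %subtract the mean;
%then scale each coordinate by its maximum absolute value, so the values fall in [-1,1];
max_h=max(abs(h)); %maximum absolute value;
max_ve=max(abs(ve));
h=h./max_h; %divide by that value;
ve=ve./max_ve;
tr_set=[h;ve;y]; %combine into one data set; the responses are unchanged;
%X=[h;ve]; %take only the data points without labels, keeping the order so the responses can be recovered;
seq_tr_set=randperm(N); %shuffle;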
%---------------------------------------------------------------------------------------------------------------
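%note that randperm shuffles the order, so the points are visited at random, independent of the order in which they were generated;
%---------------------------------------------------------------------------------------------------------------
mix_tr_set=tr_set(:,seq_tr_set); %the first 300 of the 2300 points, in shuffled order, form the training set;
h=mix_tr_set(1,:); %first row: x-axis data, i.e. x1;
ve=mix_tr_set(2,:);
y=mix_tr_set(3,:); %note that the responses travel with the data points;
%tr_set_1=[h;ve;y];
X=[h;ve]; %sometimes only the two coordinate rows are needed, but the responses still follow the data points;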
plot(h,ve,'.');
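%the steps above normalize the generated data, which makes it easier to process; data generation is now complete;
fprintf('Now we start the SVM by SMO algorithm ...\n');
fprintf('-----------------------------------------\n');
%from here on we build the SMO algorithm;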
%------------------------------------------------------
%set the SVM initial value;
%------------------------------------------------------
a=zeros(1,N); %the first initial value; these are the Lagrange multipliers, one per data point;
b=0; %bias;
C=100; %the upper bound (box constraint); the penalty parameter;
FX=zeros(1,N); %f(x(i)), i=1..N; in general f(x)=w'*x-b, here f(xi)=SUM_j aj*yj*K(xj,xi)-b; since all a(i)=0 and b=0, f(x)=0;
E=FX-y; %y: the expected response; E is the difference between f(xi) and the expected response (the label);
%??E should be followed by training set;
efxt=a./C; %efxt holds the slack (margin-gap) values; these matter for the points whose multiplier equals C;
eps=0.007; %stopping threshold: stop when (primal objective - dual objective)/(primal objective + 1) falls below it; note that this shadows MATLAB's built-in eps;
times=0; %number of outer (external) loops completed so far; 0 at the start;
presv=0; %flag: whether a multiplier violating the KKT conditions has been found; none at the beginning;
Gram=eye(N,N); %build the Gram matrix, for recording the calculation results;
%Gram matrix: we do not need every entry, but an entry that has already been computed need not be computed again; note the eye matrix, so the diagonal is all 1;
%this serves two purposes: first, the diagonal entries really are 1 for the Gaussian kernel; second, all other entries start at 0, so a zero entry marks a position not yet computed;
ot=0; %indicates whether there have been too many operations, i.e. the iteration has run too long;
totaltimes=0; %how many two-multiplier updates have been performed in total;
in=1; %three scan indices distinguish the three kinds of multipliers that may violate KKT: C>a>0 (in), a=C (ic), a=0 (i0); each starts from 1;
ic=1;
i0=1;
while(1) %the outer loop; it stops when ratio<eps; the first point is selected heuristically;
%select the first point;
%--------------------------------------------------------------------------
if (times==0) %first selection: on the first pass just pick any point; here point 1;
i1=1; %any point would do, so choose 1;
else %if this is not the first selection:
%----------------------------------------------------------------------
%when other selection:
%1, choose the (0,C) alphas which break the KKT conditions;
%2, then the alpha=0 and alpha=C points which break the KKT conditions;
%3, if the chosen alpha is the last in the sequence, restart from 1;
%----------------------------------------------------------------------
while (in<=N) %search all alphas in (0,C); after the first pass, consider first the non-bound multipliers that violate the KKT conditions;
if (a(in)>0) && (a(in)<C) %here, C is 100.
if (y(in)*FX(in)>1.01)||(y(in)*FX(in)<0.99) %KKT check: a violation cannot be tested as "not equal to 1"; use a tolerance band of 0.99~1.01 instead;
i1=in; %if it breaks the KKT rule, i1=in;
presv=1; %a violator was found: record its index and set the flag;
end
end
in=in+1; %if no KKT violator has been found yet, continue the scan;
if (presv) %this emulates a do-while structure: test the flag and exit once a violator is found;
break;
end
end
if presv==0 %no violating non-bound multiplier was found;
%if all non-bound multipliers satisfy the KKT conditions, consider the bound ones (a=0 or a=C)
in=1; %reset the non-bound scan index to 1;
while (ic<=N)
%if the alpha is on the boundary, i.e. a=0 or C;
if a(ic)==C %consider the a=C multipliers first, since these are also support vectors;
if (y(ic)*FX(ic)>1.01) %a tolerance is used here as well;
i1=ic; %if break the KKT rule, i1=i;
presv=1; %we found a SV already;
end
end
ic=ic+1;
if (presv)
break;
end
end
end
if presv==0
in=1;
while (i0<=N)
%if the alpha is on the boundary, i.e. a=0 or C;
if a(i0)==0 %now consider the a=0 multipliers; there are many of these, and they are not support vectors;
if (y(i0)*FX(i0)<.99)
i1=i0; %if break the KKT rule, i1=i;
presv=1; %we found a SV already;
end
end
i0=i0+1;
if (presv)
break;
end
end
end
if (presv==0) %if no multiplier violating KKT was found;
in=1;
ic=1;
i0=1;
i1=floor(rand*N)+1; %if all multipliers satisfy the KKT conditions, pick one at random;
end
presv=0; %reset the flag to its initial value, i.e. no violator found;
end
%--------------------------------------------------------------------------
%{
--------------------------------------------------------
now that the first multiplier has been found, search for the second one.
--------------------------------------------------------
%}
max_i=1;
times_2=0; %how many times of internal loop?
i2old=i1; %the initial value of i2;
while(1) %the important question is how to
%stop the loop and this is internal
%loop, too.
if (times==0) %on the first pass most points have a(i)=0, so every multiplier gets updated once;
if i2old<N %i2old from i1:N; i.e. 1:N.
i2=i2old+1; %i2 runs from 2 to N;
i2old=i2; %remember the choice;
end
else %on later passes some points have a(i)>0; look for the point that maximizes |E1-E2|;
min_E=E(i1);
max_E=E(i1);
if (E(i1)>=0)
for i=1:N
if (a(i)>0) && (a(i)<C)
if (E(i)<min_E)%&&(i~=i2old_x)
min_E=E(i);
i2=i;
%i2old_x=i2;
end
end
end
else
for i=1:N
if (a(i)>0) && (a(i)<C)
if (E(i)>max_E)%&&(i~=i2old_m)
max_E=E(i);
i2=i;
%i2old_m=i2;
end
end
end
end
end
%end
times_2=times_2+1; %totally N for times=0;
if times_2>=NS %finish one complete loop;
break;
end
totaltimes=totaltimes+1; %total times of calculation;
%---------------------------------------------------------------
%now I have two points, start the calculation;
%---------------------------------------------------------------
%now we get the i2, so we could start the optimization;
%select the second point;
%SMO optimization algorithm; from here on it is the SMO computation: two points are updated each time, and the updated Fx1, Fx2 are then written back;
%{
possible parameters;
SV(i); collect all support vectors into one group and calculate the E1&E2
using these elements;
x1,x2;
x_sv(i) and y_sv(i) compare to the SV(i);
---------------------------------------------
maybe we don't need to collect the SV group.
we just choose the {a(i)>0}is okay,
when the algorithm stops, we could update the
SVs and give the final decision hyper plane.
---------------------------------------------
a1,a2;
a2new,a2new-unc,a2old;a1new,a1old;
y1,y2;
E1,E2;
L,H;
K11,K22,K12;
%}
%and we need the kernel function;
%{
K(x,z)=exp(-1/(2*sigma^2)*norm(x-z)^2); note that K.m takes the variance sigma^2 (the value si) as its third argument;
%}
%--------------------------------------------------------------------------
%
%RUN the SMO algorithm
%
%--------------------------------------------------------------------------
x1=X(:,i1); %most of this is easy to follow; only the less obvious parts are explained;
x2=X(:,i2);
y1=y(i1);
y2=y(i2);
a1old=a(i1);
a2old=a(i2);
%when C=inf, the limits should be adjusted as below.
if(y1~=y2)
L=max(0,a2old-a1old);
H=min(C,C-a1old+a2old);
else
L=max(0,a1old+a2old-C);
H=min(C,a1old+a2old);
end
if times~=0 && L>=H %first time, we should let L=H=0;
break;
end
K11=1;%K(x1,x1); for the Gaussian kernel function;
K22=1;%K(x2,x2); ditto;
if Gram(i1,i2)~=0 %note the caching scheme: a nonzero Gram entry means this pair has already been evaluated, so the kernel call can be skipped;
K12=Gram(i1,i2);
else
K12=K(x1,x2,si);
Gram(i1,i2)=K12;
end
%we need two loop for E1&E2 calculation;
%l, the number of support vectors; at the first, this value is 0, then grow
%slowly.
%I need the mapping between support vectors and original points.
%set N(i)=number of the original points, that means the pointer. such as
%N(1)=30, means the first support vector is the 30th points.
%so the pointer of original points is important.
%and I want to know the verse # of the SV, i.e.O(30)=1, means the 30th
%point in the original points, is the first SV.
E1=E(i1); %this value got from memory;
E2=E(i2); %ditto;
k=K11+K22-2*K12; %k must not be 0, because it is used as a divisor.
if k==0 %this is rare; it happens when x1 and x2 coincide.
k=0.01;
end
%?? how about when k=0???!!!!!----------------------
%---------------------------------------------------
a2new=a2old+y2*(E1-E2)/k; %this is the multiplier update itself;
if (a2new>H)
a2new=H;
else
if(a2new<L)
a2new=L;
end
end
a1new=a1old+y1*y2*(a2old-a2new);
a1new=max(0,a1new); %update of multiplier a1;
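%{
For reference, the two-multiplier step above (the bounds L and H, the update of a2, the clipping, then a1) can be collected into one function. This is a sketch, not the script's own code: the name smo_pair_step is hypothetical, it assumes the Gaussian kernel so that K11=K22=1, and it guards against eta<=0 where the script tests k==0.

```matlab
% Hypothetical helper (not part of the original script): one SMO pair update.
function [a1new, a2new] = smo_pair_step(a1, a2, y1, y2, E1, E2, K12, C)
    if y1 ~= y2
        L = max(0, a2 - a1);
        H = min(C, C - a1 + a2);
    else
        L = max(0, a1 + a2 - C);
        H = min(C, a1 + a2);
    end
    eta = 2 - 2*K12;                   % K11 + K22 - 2*K12 with K11 = K22 = 1
    if eta <= 0
        eta = 0.01;                    % same fallback value the script uses for k
    end
    a2new = a2 + y2*(E1 - E2)/eta;     % unconstrained optimum along the constraint line
    a2new = min(max(a2new, L), H);     % clip to [L, H]
    a1new = a1 + y1*y2*(a2 - a2new);   % keeps y1*a1 + y2*a2 unchanged
    a1new = max(0, a1new);             % guard against round-off, as in the script
end
```

For instance, starting from a1=a2=0 with y1=1, y2=-1, E1=-1, E2=1, K12=0.5 and C=100, the step gives a1new=a2new=2, so y1*a1+y2*a2 stays at 0.
%}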
%{
if abs(a2new-a2old)<(eps*(a2new+a2old-eps)) %if the difference is little, jump out of the loop;
break;
end
%}
a(i1)=a1new; %now we could update the{a(i)}i=1toN
a(i2)=a2new; %update the {a(i)};
%now we starts update the bias;
%--------------------------------------------------------------------------
bold=b;
a1e=y1*(a1new-a1old);
a2e=y2*(a2new-a2old);
b1new=E1+a1e+a2e*K12+bold; %bias update;
b2new=E2+a1e*K12+a2e+bold;
if a1new>0 && a1new<C %if a1new is in the bounds;
b=b1new;
else
if a2new>0 && a2new<C
b=b2new;
else
if (a1new==0||a1new==C)&&(a2new==0||a2new==C)&&(L~=H)
b=(b1new+b2new)/2;
end
end
end
%--------------------------------------------------------------------------
%update F(x1),F(x2);
FX1=0;
FX2=0;
for (i=1:N)
if a(i)>0
if Gram(i,i1)~=0;
Ki1=Gram(i,i1);
else
Ki1=K(X(:,i),x1,si);
Gram(i,i1)=Ki1;
end
FX1=FX1+a(i)*y(i)*Ki1;
if Gram(i,i2)~=0;
Ki2=Gram(i,i2);
else
Ki2=K(X(:,i),x2,si);
Gram(i,i2)=Ki2;
end
FX2=FX2+a(i)*y(i)*Ki2;
end
end
FX(i1)=FX1-b; %FX=SUM ai*yi*K(xi,xj)-b;
FX(i2)=FX2-b;
E(i1)=FX(i1)-y1; %store the E(i) into the E matrix;
E(i2)=FX(i2)-y2;
%N, the number of training set;
%now we calculate the stop formula and judge whether the algorithm should
%stop.
end
%---------------------------------------------------------------
%one internal loops complete, times++ for external loop;
%---------------------------------------------------------------
%C=max(a)+1;
%--------------------------------------------------------------------------
% I hope that calculation could be complete less than (n-1)*(n-2)/2 times;
%--------------------------------------------------------------------------
if totaltimes>100*N %too many iterations, i.e. the run is taking too long: terminate automatically;
fprintf('how many knives?\n');
fprintf('-----------------------------------------\n');
break;
end
times=times+1; %times++, i.e. outer loop ++;
%do we need to update E(i) and FX(i)? I think so. FX(i) is necessary, E(i)
%is not.
%now I hesitate.
%from here, record all support vectors and update every F(xi), i=1:N;
i=1;
l=0;
while i<=N
if a(i)>0
l=l+1;
SV(l)=a(i);
y_sv(l)=y(i);
x_sv(:,l)=X(:,i);
ptr(l)=i; %remember the pointer;
end
i=i+1;
end
NS=l; %the number of support vectors.
%{
if NS<=(N/10)
p=0;
for j=1:NS;
for i=1:NS;
d_max=norm(x_sv(:,i)-x_sv(:,j));
if d_max>p;
p=d_max;
end
end
end
d_max=p;
si=d_max^2/(2*NS);
end
%}
%{
lold=1;
while (a(lold)==C)
lold=lold+1;
end
FV=0;
for l=1:NS
FV=FV+SV(l)*y_sv(l)*K(x_sv(:,l),x_sv(:,lold));
end
b=FV-y_sv(lold);
lold=lold+1;
if lold>NS
lold=1;
end
%}
FX=zeros(1,N);
for i=1:N;
for l=1:NS
if Gram(ptr(l),i)~=0;
Kli=Gram(ptr(l),i);
else
Kli=K(x_sv(:,l),X(:,i),si);
Gram(ptr(l),i)=Kli;
end
FX(i)=FX(i)+SV(l)*y_sv(l)*Kli;
end
end
sv_x=x_sv; %this vector is for demonstration.
x_sv=zeros(2,NS); %clear past data for preparation.
FX=FX-b*ones(1,N);
E=FX-y;
for i=1:N
efxt(i)=max(0,1-y(i)*FX(i));
end
%-------------------------------------
%Now we revise the stop criteria: this is the computation of the stopping condition;
%-------------------------------------
W=0;
AAYYK=0;
for i=1:N
if (a(i)>0)
for j=1:N
if (a(j)>0)
if Gram(i,j)~=0
Kij=Gram(i,j);
else
Kij=K(X(:,i),X(:,j),si);
Gram(i,j)=Kij;
end
AAYYK=AAYYK+a(i)*a(j)*y(i)*y(j)*Kij;
end
end
end
end
suma=sum(a);
sume=C*sum(efxt);
W=suma-0.5*AAYYK;
ratio=((suma-2*W+sume)/(suma-W+sume+1)); %the ratio must be close to 0 before we stop;
fprintf('Now we can see the difference between original objective and dual one...\n');
fprintf('Ratio = %f\n',ratio);
if (ratio<eps)
break; %if ratio meet the stop threshold, loop
%stops.
end
end %this end compare to while(1);
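%{
The stopping test inside the loop compares the primal and dual objectives. A compact, vectorized sketch of the same ratio is given here; the function name gap_ratio and the availability of a fully populated Gram matrix G are assumptions of this sketch, whereas the script fills Gram lazily.

```matlab
% Hypothetical helper (not part of the original script): relative duality gap.
% a, y, FX are 1-by-N rows; G is the N-by-N Gram matrix; C is the penalty parameter.
function ratio = gap_ratio(a, y, FX, G, C)
    ay     = a .* y;
    quad   = ay * G * ay';              % sum_ij a_i*a_j*y_i*y_j*K_ij (AAYYK in the script)
    dual   = sum(a) - 0.5*quad;         % W in the script
    slack  = max(0, 1 - y .* FX);       % efxt in the script
    primal = 0.5*quad + C*sum(slack);   % 0.5*||w||^2 + C*sum of slacks
    ratio  = (primal - dual) / (primal + 1);
end
```

Expanding W=suma-0.5*AAYYK shows that the script's numerator suma-2*W+sume equals primal-dual and its denominator suma-W+sume+1 equals primal+1, so the two expressions agree.
%}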
%now the SV calculation is complete.
%some SVs should have appeared by now, so I can collect the SV group,
%check the validation data set and draw the decision boundary.
%after the run finishes, record the support vectors;
i=1;
l=0;
while i<=N
if a(i)>0
l=l+1;
SV(l)=a(i);
y_sv(l)=y(i);
x_sv(:,l)=X(:,i);
ptr(l)=i;
end
i=i+1;
end
NS=l; %the number of support vectors.
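%now we should update the value of b: pick a support vector at random, preferring a non-bound one (a<C);
lold=floor(rand*NS)+1;
tries=0;
while (SV(lold)==C) && (tries<NS) %index the SV list rather than a(); wrap around and give up after NS tries;
lold=mod(lold,NS)+1;
tries=tries+1;
end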
FV=0;
for l=1:NS
if Gram(ptr(l),ptr(lold))~=0
Kllo=Gram(ptr(l),ptr(lold));
else
Kllo=K(x_sv(:,l),x_sv(:,lold),si);
Gram(ptr(l),ptr(lold))=Kllo;
end
FV=FV+SV(l)*y_sv(l)*Kllo;
end
b=FV-y_sv(lold);
%now all support vectors and their expected responses and inputs are collected;
%now I can check the validation set;
% this module produces the validation set.
%------------------------------------------------------------------
seq_tr_set=N+randperm(NN); %shuffle the 2000 points that were not used for training (columns N+1..N+NN);
mix_tr_set=tr_set(:,seq_tr_set); %take the 2000 validation points out of the 2300;
h=mix_tr_set(1,:); %first row: x-axis data, i.e. x1;
ve=mix_tr_set(2,:);
y=mix_tr_set(3,:); %note that the responses travel with the data points;
tr_set=[h;ve;y];
X=[h;ve]; %sometimes only the two coordinate rows are needed, but the responses still follow the data points;
%------------------------------------------------------------------
% no further normalization is needed: tr_set was already normalized above.
err=0;
for j=1:NN
FV=0;
for l=1:NS
FV=FV+SV(l)*y_sv(l)*K(X(:,j),x_sv(:,l),si);
end
FV=FV-b;
if (y(j)*FV)>=1
FV=0;
else
if (y(j)*FV)<0
err=err+1;
FV=0;
end
end
end
%now the validation is complete.
%start drawing the decision curve, i.e. the decision boundary;
fprintf('Draw the judgement curve ...\n');
fprintf('------------------------------------\n');
figure;
plot(h,ve,'.');
hold on;
plot(x_sv(1,:),x_sv(2,:),'r*');
for il=1:400;
xl(il)=-1+il/200;
ul=10;
pl=2;
for jl=1:400;
yl(jl)=-1+jl/200;
Cur=[xl(il) yl(jl)]';
oo=0;
for l=1:NS
oo=oo+SV(l)*y_sv(l)*K(Cur,x_sv(:,l),si);
end
oo=oo-b;
zl=abs(oo);
if zl<ul;
ul=zl;
pl=yl(jl);
end
end
pl_l(il)=pl;
end
hold on;
plot(xl,pl_l,'r.');
fprintf('run time = %4.2f seconds\n',cputime-st); %report the elapsed time to judge computational efficiency;
%---------------------------------------------------------------------------
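%the main program ends here; the kernel function follows and must be placed in the same directory;
%---------------------------------------------------------------------------
function f=K(x,z,sigma)
%sigma=0.25;
f=exp(-1/(2*sigma)*norm(x-z)^2); %sigma is used as the variance here (the value si in the main program);
%save this function as K.m and put it in the same directory as the main program.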
%{
l=0;
for i=1:N
if y(i)*FX(i)==1
l=l+1;
end
end
%}
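%{
Once SV, y_sv, x_sv and b have been collected, the validation loop and the boundary plot both evaluate the same decision function. A vectorized sketch of that evaluation follows; the function name svm_predict is hypothetical, s2 stands for the kernel variance si, and the tie-break for f=0 is an arbitrary choice of this sketch.

```matlab
% Hypothetical helper (not part of the original script): classify query points.
% Xnew: 2-by-M query points; x_sv: 2-by-NS support vectors;
% SV, y_sv: 1-by-NS multipliers and labels; b: bias; s2: kernel variance.
function lab = svm_predict(Xnew, x_sv, SV, y_sv, b, s2)
    sqn  = sum(Xnew.^2, 1);
    sqs  = sum(x_sv.^2, 1);
    D2   = bsxfun(@plus, sqs', sqn) - 2*(x_sv'*Xnew);  % NS-by-M squared distances
    Kmat = exp(-max(D2, 0) ./ (2*s2));                 % Gaussian kernel values
    f    = (SV .* y_sv) * Kmat - b;                    % f(x) = SUM a*y*K - b
    lab  = sign(f);
    lab(lab == 0) = 1;                                 % arbitrary tie-break
end
```

Under these assumptions, sum(svm_predict(X, x_sv, SV, y_sv, b, si) ~= y) counts the points on the wrong side of the boundary, which is what err accumulates in the validation loop.
%}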