1. 程式人生 > >支援向量機SMO演算法實現(原始碼逐條解釋)

支援向量機SMO演算法實現(原始碼逐條解釋)

支援向量機號稱機器學習中最好的演算法——存在最優解,而且一般問題都可以得解。但是演算法需要的儲存空間和計算複雜度較大,不大適合大資料量的運算,不過經過platt發明的SMO簡化運算後,效率可以提高很多。以下是筆者用Matlab語言寫的支援向量機兩分類問題的原始碼,因為在網路上得到各位前輩的指點,受益頗多,因此不敢藏私,分享如下,也歡迎有志於機器學習、神經網路、模式識別和時間序列預測的各位一起討論,共同完善,謝了先。

支援向量機的理論、原理可以參考一下連結,並藉此感謝原文作者July、pluskid ;

http://blog.csdn.net/v_july_v/article/details/7624837


這是雙月兩分類問題的分類結果,訓練樣本300個點,驗證樣本2000個點;紅色為決策超平面,紅色*點是支援向量,分類誤差5個點,誤差率0.25%,考慮到這是一種不可分模式,這樣的誤差可以接受了。

言歸正傳,原始碼如下:

%SMO algorithm solve the SVM optimization question;
%build the inner kernel at first, since the optimization algorithm is the
%easiest.
%then develop to the selection method;
%-------------------20140204----------------------------------------------
%revise the stop judgement:whether the alpha parameters and related points
%meet the KKT conditions. no obvious improvement.
%17:16 try another method: C-SVM;
%at first, C=100;
%and at the same time, I need efxt;
%-------------------------------------------------------------------------

%{

上面是對這個問題的大致說明,我的方法是由計算核心(核函式)開始,逐步往上構建整個演算法,由於運算涉及到多次使用核函式,因此如何簡化核函式、避免重複運算,成為提升效率、優化程式碼的關鍵;

當然,在這之前,我也寫下了整個程式流程圖,這讓我知道每一步的意義,也知道在程式設計的過程中,程式碼處於什麼位置,又將向何處去。

%}


clear all;                      
close all;


%關閉圖形和清除記憶體中各引數;

st = cputime;                   %record current time; 這個是為了統計整體計算時間;


%-------------------------------------------------
%    set the initial value;  
%-------------------------------------------------


N=300;                          %the training data set size;                       300個訓練點;
NN=2000;                        %the number of validation data set;     2000個驗證點;
NS=N;                           %number of SVs.                                         支援向量的數目,一開始假設所有資料點都是支援向量;


%------------------------------------------------
%produce the training set                                
%------------------------------------------------                                


    fprintf('Producing the training set...\n');
    fprintf('-----------------------------------------\n');

%運算中我們需要知道程式進行到哪一步了,這就需要在程式碼中嵌入這些指示性文字;



r=10;                               %半徑=10;
wth=6;                          %寬度=6;
dis=-6.5;                             %距離=-6;
si=0.085;

%構建雙月形狀,設定半徑、寬度、距離、以及核函式的方差,這個方差不是一蹴而就的,可以在第一次運算後,通過統計支援向量之間的最大距離和支援向量數目運算得出;

for i=1:N+NN;                       %at the beginning of every round, to produce the training set in random; 看了Yanbo Xue的程式,他是一次性產生3000個數據點,
                                    %然後用這三千個點中的一千個反覆訓練,訓練50個回合;先試試他的方法;
                                    %make the rand vectors X;
    if  rand>0.5;                   %make sure the points occur in random. and half in half for each class;
        theta=pi*rand;              %set the angle;
        lenth=r+wth*(rand-0.5);     %the range;
        h(i)=lenth*cos(theta);      %the X axis;
        ve(i)=lenth*sin(theta);     %the Y axis;
        y(i)=-1;                     %the expected response; class 1;
    else
        beta=-pi*rand;
        len_2=r+wth*(rand-0.5);
        h(i)=len_2*cos(beta)+r;
        ve(i)=len_2*sin(beta)-dis;
        y(i)=1;                   %class 2;
    end
      
end

%這裡是產生雙月形狀的程式碼,主要到數量是N+NN,也就是說訓練集和驗證集產生於同一個模型;
 
    fprintf('Normalization...\n');
    fprintf('-----------------------------------------\n');
    %the input data need to deal with;去均值和正則化?
    % and use the random data.
    %首先去掉均值;


    miu_x=mean(h);                 %取得x1的均值;
    miu_y=mean(ve);                %取得x2的均值;                
    h=h-miu_x*ones(1,NN+N);           %減去均值;相當於減去了直流分量;
    ve=ve-miu_y*ones(1,NN+N);         %減去均值;
    %正則化?高斯化?讓取值落在[0,1]區間?


    max_h=max(abs(h));             %取得絕對值的最大值;
    max_ve=max(abs(ve));        
    h=h./max_h;                    %除以這個數;
    ve=ve./max_ve;
    tr_set=[h;ve;y];               %組合成一個訓練集,響應沒有變化;
    %X=[h;ve];                      %只取資料點,不取標號,但不打亂順序,這樣可以恢復出響應來;
    seq_tr_set=randperm(N);         %重新洗牌;                     

%---------------------------------------------------------------------------------------------------------------

%注意到這裡,randperm的意思是打亂次序,也就是隨機取點,不受產生順序的限制;

%---------------------------------------------------------------------------------------------------------------
    mix_tr_set=tr_set(:,seq_tr_set);  %注意到從2300個點中取得300個;
    h=mix_tr_set(1,:);                %第一行是x軸資料,或x1;
    ve=mix_tr_set(2,:);              
    y=mix_tr_set(3,:);                %注意到響應是跟著資料點走的;
    %tr_set_1=[h;ve;y];       
    X=[h;ve];                    %某些時候只需要用到兩列,但是響應還是隨著資料點走的;
    
    plot(h,ve,'.');

%以上步驟是把產生的資料歸一化,這樣便於處理;這樣資料已經產生完畢;
    
    fprintf('Now we start the SVM by SMO algorithm ...\n');
    fprintf('-----------------------------------------\n');
    

%從這裡開始,我們準備構建SMO演算法;
    %------------------------------------------------------
    %set the SVM initial value;
    %------------------------------------------------------


    a=zeros(1,N);                   %the first initial value; 這就是著名的拉格朗日運算元,每個資料點都有一個;
    b=0;                            %bias.                                   偏置;
    C=100;                          %the boundry;                  懲罰運算元;  
    FX=zeros(1,N);                  %f(x(i)) i=1toN, since {a(i)}=0, so f(x)=0; f(x)=w'*x-b; or f(x)=SUM ai*yi*K(xi,xj)-b=0; f(xi)函式,一般有f(xi)=w'×xi-b;這裡是f(xi)=sum ai*yi*K(xi,xj)-b;
    E=FX-y;                         %y: the expected response;      這是函式f(xi)與期望響應(或者叫做標號)的差異;
                                    %??E should be followed by training set;
    efxt=a./C;                      %efxt is the gap distance related parameters;                                這個是對於那些運算元=C的點對應的間隔;
    
    eps=0.007;                      %threshold for stopping judgement;                                 閾值,(原目標-對偶目標)/(原目標+1)的比值對應停機條件,小於這個閾值,運算停止;
    times=0;                        %at the start, times=0;the number of                                   外迴圈運算多少次了;
                                    %external loops;
    presv=0;                        %at the beginning, suppose there is no                              相當於flag,表示是否找到不滿足KKT條件的支援向量;
                                    %support vector found;
    Gram=eye(N,N);                 %build the gram matrix, for recording the calculation results;  

                                                   %Gram矩陣,我們不需要運算矩陣中所有的點,但是對於已經計算了的點,我們無需重複計算;注意到採用的是eye矩陣,這樣對角資料全為1;

                                                   %這樣有兩個用意,第一,資料確實為1,第二,其他資料為0,這樣可以判斷這個位置是否運算過了;
                                   
                                   
    ot=0;                                      %指示是否運算數目太多,迭代時間太長了;                    
    totaltimes=0;                   %how many times of calculation operated?   一共執行多少次兩運算元計算了;
    in=1;                           %set the initial value of in;in start from 1;                為了區分違反KKT條件的三種運算元,有三種指示,a=0;a=C;C>a>0;
    ic=1;
    i0=1;
    
while(1)                           %the first loop, when ratio<eps, loop stop;                  開始外迴圈,採用啟發式方法選擇第一個點;
%select the first point;
%--------------------------------------------------------------------------
if (times==0)                      %first selection or calculation;                                     如果演算法初次執行,隨意選擇一個點,這裡選擇點1;
    i1=1;                          %select the point in random, so chose 1;  
   
else                               %if this is not the first selection:
    %----------------------------------------------------------------------
    %when other selection: 
    %1, choose the (0,C) alpha which break the KKT conditions;
    %2, then alpha=0 and alpha=C points which break the KKT conditions;
    %3, if alpha choosen is in the last of the squene, re-start from 1;
    %----------------------------------------------------------------------
    
    while (in<=N)                  %search in all (0,C)alpha;                                          以後的運算,首先考慮那些界內的、不滿足KKT條件的運算元;
        if (a(in)>0) && (a(in)<C)  %here, C is 100.
            if (y(in)*FX(in)>1.01)||(y(in)*FX(in)<0.99)   %meet the KKT condition or not:               這裡不能用不等於1來表示違反KKT條件,而是用一個容許範圍,取0.99~1.01;
                i1=in;             %if break the KKT rule, i1=i;  
                presv=1;           %we found a SV already;                                                                         如果滿足,記錄下當前序號,並改變flag標誌;
            end
        end        
        in=in+1;                   %if we don't found the break KKT condition SV, continue...               迴圈進行;
        if (presv)                 %we've found SV and not exceed the NS times;                                   這裡採用的是do while結構,所以採用了一個判決語句,滿足條件時退出;
            break;
        end                
    end
    
    if presv==0                    %if we don't found SV which break KKT condition, or, the times_i1 out of NS times;        

                                             %如果界內的運算元都滿足KKT條件,考慮界上的運算元(a=0或a=C)        
        in=1;                            %及時對界內運算元的序號復位,這裡復位為1; 
        while (ic<=N)
                                   %if the alpha on the boundry,i.e. a=0 or C;
            if a(ic)==C                 %這裡首先考慮a=C的運算元,因為這也是支援向量;
                if (y(ic)*FX(ic)>1.01)                        %注意到這裡也採用了容許範圍;
                    i1=ic;         %if break the KKT rule, i1=i;                    
                    presv=1;       %we found a SV already;                
                end
            end
            ic=ic+1;
            if (presv)
                break;
            end            
        end
    end
    
    if presv==0
        in=1;
        
        while (i0<=N)
                                   %if the alpha on the boundry,i.e. a=0 or C;
            if a(i0)==0     %現在考慮a=0的運算元,這些運算元是非常多的;這些就不是支援向量了;
                if (y(i0)*FX(i0)<.99)
                    i1=i0;         %if break the KKT rule, i1=i;                    
                    presv=1;       %we found a SV already;                
                end
            end
            i0=i0+1;
            if (presv)
                break;
            end           
        end
    end
      
    
    if (presv==0)                  %if we didn't found a SV which break KKT;        
        in=1;
        ic=1;
        i0=1;
        i1=floor(rand*N)+1;      %如果所有的支援向量都滿足KKT條件,那麼隨意選擇一個運算元;
    end
    
    presv=0;                       %back to the initial value;i.e.no sv found;           標誌位清零;
    
end
%--------------------------------------------------------------------------       
    
  
%{
  --------------------------------------------------------
  now I have the first point, and search the second point.

第一個運算元找到之後,開始第二個運算元;
  --------------------------------------------------------
%}
                           
max_i=1;                 
times_2=0;                            %how many times of internal loop?
i2old=i1;                             %the initial value of i2;
while(1)                              %the important question is how to
                                      %stop the loop and this is internal
                                      %loop, too.


if (times==0)                         %at the first time, most points have                初次運算,所有的運算元都要更新一下;
                                      %a(i)=0;
   if i2old<N                        %i2old from i1:N; i.e. 1:N.
        i2=i2old+1;                   %i2 from 2:N+1;
        i2old=i2;                     %remember the choice;
   end
   
else                                  %at the other loops,there should be              在之後的運算中,我們要尋找那些|E1-E2|最大的點;
                                      %some points a(i)>0;
    min_E=E(i1);                          
    max_E=E(i1);
  
    if (E(i1)>=0)
        for i=1:N       
            if (a(i)>0) && (a(i)<C)
                if (E(i)<min_E)%&&(i~=i2old_x)
                    min_E=E(i);
                    i2=i; 
                    %i2old_x=i2;
                end
            end            
        end        
    else
        for i=1:N
            if (a(i)>0) && (a(i)<C)
                if (E(i)>max_E)%&&(i~=i2old_m)
                    max_E=E(i);
                    i2=i;
                    %i2old_m=i2;
                end
            end            
        end
    end
  end
%end
    times_2=times_2+1;               %totally N for times=0; 
if times_2>=NS            %finish one complete loop;         
   break;
end
    totaltimes=totaltimes+1;         %total times of calculation;
    
    %---------------------------------------------------------------
    %now I have two points, start the calculation;
    %---------------------------------------------------------------

%now we get the i2, so we could start the optimization;
       
%select the second point;


%SMO optimization algorithm;            從現在開始就是SMO的運算了,每次更新兩個點,然後將更新後的Fx1, Fx2上傳;

%{
possible parameters;
SV(i); collect all support vectors into one group and calculate the E1&E2
using these elements;
x1,x2;
x_sv(i) and y_sv(i) compare to the SV(i); 
---------------------------------------------
maybe we don't need to collect the SV group.
we just choose the {a(i)>0}is okay,
when the algorithm stops, we could update the
SVs and give the final decision hyper plane.
---------------------------------------------
a1,a2;
a2new,a2new-unc,a2old;a1new,a1old;
y1,y2;
E1,E2;
L,H;
K11,K22,K12;
%}


%and we need the kernel function;
%{
    K(x,z)=exp(-1/(2*sigma^2)*norm(x-z)^2);
%}
%--------------------------------------------------------------------------
%
%RUN the SMO algorithm
%
%--------------------------------------------------------------------------
x1=X(:,i1);                               %這裡都比較容易看懂了,只解釋比較隱晦的部分;
x2=X(:,i2);
y1=y(i1);
y2=y(i2);
a1old=a(i1);
a2old=a(i2);
%when C=inf, the limits should be adjusted as below.
if(y1~=y2)
    L=max(0,a2old-a1old);
    H=min(C,C-a1old+a2old);
else
    L=max(0,a1old+a2old-C);
    H=min(C,a1old+a2old);
end


if times~=0 && L>=H                         %first time, we should let L=H=0;
    break;
end


K11=1;%K(x1,x1); for Guass kernel function;
K22=1;%K(x2,x2); ditto;
if Gram(i1,i2)~=0                         %注意到這裡的選擇機制,一旦Gram中的元素不等於0,就說明這兩個點已經運算過,可以跳過了;
    K12=Gram(i1,i2);
else
    K12=K(x1,x2,si);
    Gram(i1,i2)=K12;    
end


%we need two loop for E1&E2 calculation;
%l, the number of support vectors; at the first, this value is 0, then grow
%slowly.
%I need the mapping between support vectors and original points. 
%set N(i)=number of the original points, that means the pointer. such as
%N(1)=30, means the first support vector is the 30th points.
%so the pointer of original points is important.
%and I want to know the verse # of the SV, i.e.O(30)=1, means the 30th
%point in the original points, is the first SV.




E1=E(i1);                        %this value got from memory;
E2=E(i2);                        %ditto;


k=K11+K22-2*K12;                 %parameter couldn't be 0 when it worked as divider.


if k==0                          %the possibility is rare.
    k=0.01;
end
%?? how about when k=0???!!!!!----------------------
%---------------------------------------------------
a2new=a2old+y2*(E1-E2)/k;                    %這就是運算元的更新運算了;


if (a2new>H)
    a2new=H;
else
    if(a2new<L)
        a2new=L;
    end    
end


a1new=a1old+y1*y2*(a2old-a2new);
a1new=max(0,a1new);                               %運算元a1的更新;
%{
if abs(a2new-a2old)<(eps*(a2new+a2old-eps))          %if the difference is little, jump out of the loop;
    break;
end
%}
a(i1)=a1new;                           %now we could update the{a(i)}i=1toN
a(i2)=a2new;                           %update the {a(i)};
%now we starts update the bias;
%--------------------------------------------------------------------------
bold=b;
a1e=y1*(a1new-a1old);
a2e=y2*(a2new-a2old);
b1new=E1+a1e+a2e*K12+bold;                        %偏置的更新;
b2new=E2+a1e*K12+a2e+bold;


if a1new>0 && a1new<C                  %if a1new is in the bounds;
    b=b1new;
else
    if a2new>0 && a2new<C
        b=b2new;
    else 
        if (a1new==0||a1new==C)&&(a2new==0||a2new==C)&&(L~=H)
            b=(b1new+b2new)/2;
        end
    end
end
%--------------------------------------------------------------------------
%更新F(x1),F(x2);
FX1=0;
FX2=0;
for (i=1:N)
    if a(i)>0
        if Gram(i,i1)~=0;
            Ki1=Gram(i,i1);
        else
            Ki1=K(X(:,i),x1,si);
            Gram(i,i1)=Ki1;
        end
        FX1=FX1+a(i)*y(i)*Ki1;
        if Gram(i,i2)~=0;
            Ki2=Gram(i,i2);
        else
            Ki2=K(X(:,i),x2,si);
            Gram(i,i2)=Ki2;
        end
        FX2=FX2+a(i)*y(i)*Ki2;
    end
end
FX(i1)=FX1-b;                           %FX=SUM ai*yi*K(xi,xj)-b;
FX(i2)=FX2-b;
E(i1)=FX(i1)-y1;                        %store the E(i) into the E matrix;
E(i2)=FX(i2)-y2;


%N, the number of training set;


%now we calculate the stop formula and judge whether the algorithm should
%stop.
end
    %---------------------------------------------------------------
    %one internal loops complete, times++ for external loop;
    %---------------------------------------------------------------
%C=max(a)+1;
%--------------------------------------------------------------------------
% I hope that calculation could be complete less than (n-1)*(n-2)/2 times;
%--------------------------------------------------------------------------
    
    if totaltimes>100*N                                           %運算時間過長、迭代次數太多,演算法自動終止;
        fprintf('how many knives?\n');
        fprintf('-----------------------------------------\n');
        break;
    end
    
times=times+1;                        %times++, i.e. outer loop ++;  

%do we need update the E(i) and FX(i)? I think so. FX(i) is neccesary, E(i)
%don't need.
%now I hesitate.


%這裡開始記錄所有的支援向量,並更新所有的F(xi),i=1:N;
i=1;
l=0;
while i<=N
    if a(i)>0
        l=l+1;
        SV(l)=a(i);
        y_sv(l)=y(i);
        x_sv(:,l)=X(:,i);
        ptr(l)=i;                   %remember the pointer;
    end
    i=i+1;
end
NS=l;                               %the number of support vectors.
%{
if NS<=(N/10)
    p=0;
    for j=1:NS;
        for i=1:NS;
            d_max=norm(x_sv(:,i)-x_sv(:,j));
            if d_max>p;
                p=d_max;
            end
        end
    end
    d_max=p;
    si=d_max^2/(2*NS);
end
%}
%{
lold=1;
while (a(lold)==C)
    lold=lold+1;
end
FV=0;
for l=1:NS
    FV=FV+SV(l)*y_sv(l)*K(x_sv(:,l),x_sv(:,lold));
end
b=FV-y_sv(lold);
lold=lold+1;
if lold>NS
    lold=1;
end
%}
FX=zeros(1,N);
for i=1:N;
    for l=1:NS
        if Gram(ptr(l),i)~=0;
            Kli=Gram(ptr(l),i);
        else
            Kli=K(x_sv(:,l),X(:,i),si);
            Gram(ptr(l),i)=Kli;
        end
        FX(i)=FX(i)+SV(l)*y_sv(l)*Kli;                          
    end
end
sv_x=x_sv;                         %this vector for demostation.
x_sv=zeros(2,NS);                  %clear past data for preparation.
FX=FX-b*ones(1,N);
E=FX-y;


for i=1:N    
    efxt(i)=max(0,1-y(i)*FX(i));
end


%-------------------------------------
%Now we revise the stop criteria.這裡是關於停機條件的運算;
%-------------------------------------


W=0;
AAYYK=0;
for i=1:N
    if (a(i)>0)
        for j=1:N
            if (a(j)>0)
                if Gram(i,j)~=0
                    Kij=Gram(i,j);
                else
                    Kij=K(X(:,i),X(:,j),si);
                    Gram(i,j)=Kij;
                end
                AAYYK=AAYYK+a(i)*a(j)*y(i)*y(j)*Kij;
            end
        end
    end
end


suma=sum(a);
sume=C*sum(efxt);
W=suma-0.5*AAYYK;


ratio=((suma-2*W+sume)/(suma-W+sume+1));             %比值接近於0才行的;


fprintf('Now we can see the difference between original objective and dual one...\n');
fprintf('Ratio = %f\n',ratio);


if (ratio<eps)
    break;                          %if ratio meet the stop threshold, loop
                                    %stops.
end


end                                 %this end compare to while(1);


%now the SV calculation is complete.
%and there should be some SVs occured, now I can collect the SV group, 
%check the validation data set and draw the decision hyper plane.

%運算結束後,開始記錄支援向量;

i=1;
l=0;
while i<=N
    if a(i)>0
        l=l+1;
        SV(l)=a(i);
        y_sv(l)=y(i);
        x_sv(:,l)=X(:,i);   
        ptr(l)=i;
    end
    i=i+1;
end
NS=l;                               %the number of support vectors.


%now we should update the value of b.!!!!!!更新偏置;


lold=floor(rand*NS)+1;
while (a(lold)==C)
    lold=lold+1;
end
FV=0;
for l=1:NS
    if Gram(ptr(l),ptr(lold))~=0
        Kllo=Gram(ptr(l),ptr(lold));
    else
        Kllo=K(x_sv(:,l),x_sv(:,lold),si);
        Gram(ptr(l),ptr(lold))=Kllo;
    end
    FV=FV+SV(l)*y_sv(l)*Kllo;
end
b=FV-y_sv(lold);


%now I collect all support vectors and related expected response, input;


%now I can check the validation set.驗證;


% this module for producing the validation set.
%------------------------------------------------------------------
    seq_tr_set=randperm(NN);         %重新洗牌;
    mix_tr_set=tr_set(:,seq_tr_set);  %注意到從3000個點中取得1000個;
    h=mix_tr_set(1,:);                %第一行是x軸資料,或x1;
    ve=mix_tr_set(2,:);              
    y=mix_tr_set(3,:);                %注意到響應是跟著資料點走的;
    tr_set=[h;ve;y];       
    X=[h;ve];                    %某些時候只需要用到兩列,但是響應還是隨著資料點走的;
%------------------------------------------------------------------
% and do some normlization for this set.


err=0;


for j=1:NN
    FV=0;
    for l=1:NS        
        FV=FV+SV(l)*y_sv(l)*K(X(:,j),x_sv(:,l),si);
    end
    FV=FV-b;
    if (y(j)*FV)>=1
        FV=0;
    else
        if (y(j)*FV)<0
            err=err+1;
            FV=0;
        end
    end
end


%now the validation is complete.


%start the decision curve drawing;畫出決策超平面;


fprintf('Draw the judgement curve ...\n');
fprintf('------------------------------------\n');


figure;
plot(h,ve,'.');
hold on;
plot(x_sv(1,:),x_sv(2,:),'r*');


for il=1:400;    
    xl(il)=-1+il/200; 
    ul=10;
    pl=2;
    for jl=1:400;
        yl(jl)=-1+jl/200;
        Cur=[xl(il) yl(jl)]';
        oo=0;
        for l=1:NS
            oo=oo+SV(l)*y_sv(l)*K(Cur,x_sv(:,l),si);
        end          
        oo=oo-b;
        
        zl=abs(oo);
        if  zl<ul;
            ul=zl;
            pl=yl(jl);
        end
    end
    pl_l(il)=pl;
end
hold on;
plot(xl,pl_l,'r.');


fprintf('run time = %4.2f seconds\n',cputime-st); %統計時間,判斷運算效率;

%---------------------------------------------------------------------------

%主程式到此結束,一下是核函式;要放到同一個目錄下;

%---------------------------------------------------------------------------

function f=K(x,z,sigma)
%sigma=0.25;
f=exp(-1/(2*sigma)*norm(x-z)^2);

%這個函式命名為K.m,和主程式放到同一個目錄下就可以了。
        








































%{
l=0;
for i=1:N
    if y(i)*FX(i)==1
        l=l+1;
    end
end
%}