1. 程式人生 > >線性分類器之Fisher線性判別

線性分類器之Fisher線性判別

       在許多實際問題中,由於樣本特徵空間的類條件密度函式常常很難確定,利用Parzen窗等非引數方法估計分佈往往需要大量樣本,而且隨著特徵空間維數的增加所需樣本數急劇增加,因此在實際問題中,往往不去求類條件概率密度函式,而是利用樣本集直接設計分類器。具體說就是首先給定某個判別函式,然後利用樣本集確定判別函式中的未知引數。這種方法稱為判別函式法,並且根據其中判別函式的形式,可分為線性分類器和非線性分類器。線性分類器較為簡單,在計算機上容易實現,在模式識別中應用非常廣泛。在此討論線性分類器中的Fisher線性判別,應用統計方法解決很多實際問題的時候,經常會遇到維數問題。在低維空間裡解析上或者計算上可行的方法,在高維空間裡往往行不通,因此降低維數有時就成為處理實際問題的關鍵。

       可以考慮把d維空間的樣本投影到一直線上,形成一維空間,即把維數壓縮到一維,這在數學上總很容易辦到。然而即使樣本在d維空間裡形成若干緊湊的互相分得開的叢集,若把它投射到任意的一條直線上,也可能使幾類樣本混在一起而變得無法識別。但在一般情況下,總可以找到某個方向,使在這個方向的直線上,樣本的投影能分開的很好。問題是如何根據實際情況來找到這條最好的、最易於分類的投影線。這就是Fisher線性判別所需要解決的基本問題。

       對於兩類問題的Fisher線性判別的具體方法如下:

  • 計算各類樣本均值向量m_{i},N_{i}\omega _{i}類的樣本個數。

                                                                                m_{i}= \frac{1}{N^{_{i}}}\sum_{X\epsilon \omega _{i}}^{ }X, i=1,2

  • 計算樣本類內離散度矩陣S_{i}和總類內離散度矩陣S_{w}

                                                                    S{_{i}}=\sum_{X\epsilon \omega _{I}}^{ }\left ( X-m{_{i}} \right )\left ( X-m{_{i}} \right )^{T}, i= 1,2

                                                                                        S_{w}= S_{1}+S_{2}

  • 計算樣本類間離散度矩陣S_{b}.

                                                                             S{_{b}}=\left ( m{_{1}}-m{_{2}} \right )\left ( m{_{1}}-m{_{2}} \right )^{T}

  • 求向量w^{\ast }。為此定義Fisher準則函式

                                                                                     J_{F}\left ( w \right )=\frac{w^{T}S_{b}w}{w^{T}S_{w}w}

使得J_{F}\left ( w \right )取得最大值的w^{\ast }

                                                                                  w^{\ast }=S_{w}^{-1}\left ( m{_{1}}-m{_{2}} \right )

  • 將訓練資料集所有樣本進行投影。

                                                                                        y= \left ( w^{^{\ast }} \right )^{T}X

  • 計算在投影空間上的分隔閾值y_{0}。閾值的選取可以有不同的方案,較常用的一種為

                                                                                 y_{0}= \frac{N_{1}\widetilde{m_{1}}+N_{2}\widetilde{m_{2}}}{N_{1}+N_{2}}

另一種為

                                                                    y_{0}= \frac{\widetilde{m_{1}}+\widetilde{m_{2}}}{2}+ \frac{ln\left [ P\left ( w_{1} \right ) /P\left ( w_{2} \right )\right ]}{N_{1}+N_{2}-2}

其中,\widetilde{m_{i}}為在一維空間中各類樣本的均值:

                                                                                       \widetilde{m_{i}}= \frac{1}{N_{i}}\sum_{y\epsilon \omega _{i}}^{ }y.

樣本類內離散度\widetilde{s_{i}^{2}}和總類內離散度為\widetilde{s_{w}}

                                                                                      \widetilde{s_{i}^{2}}\sum_{y\epsilon \omega _{i}}^{ }\left (y- \widetilde{m_{i}}\right ),i=1,2

                                                                                                 \widetilde{s_{w}}=\widetilde{s_{1}^{2}}+\widetilde{s_{2}^{2}}      

  • 對於給定的X,計算它在w^{\ast }上的投影點y。

                                                                                                   y= \left ( w^{^{\ast }} \right )^{T}X          

  • 根據決策規則分類,有

                                                                                                \left\{\begin{matrix} y> y{_{0}}\Rightarrow X\epsilon w_{1}\\ y< y{_{0}}\Rightarrow X\epsilon w_{2} \end{matrix}\right.         

Fisher線性判別解決多類問題時,首先實現兩類Fisher分類,然後根據返回的型別與新的類別再做兩類Fisher分類,又能夠得到比較接近的類別,以此類推,直至所有的類別,最後得出未知樣本的類別。

資料集:

                           

資料集共有10000條資料,分為56維自變數,第57維為標記,兩類分別為1和2。

程式碼:

clc
clear
close all
data=load('訓練資料.mat');
type1=data.data(1:5000,1:56);
type2=data.data(5001:10000,1:56);
%類的均值向量
m1=mean(type1);
m2=mean(type2);
%各類內離散度矩陣
s1=zeros(56);
s2=zeros(56);
for i=1:1:4000
    s1=s1+(type1(i,:)-m1)'*(type1(i,:)-m1);
end
for i=1:1:4000
    s2=s2+(type2(i,:)-m2)'*(type2(i,:)-m2);
end
%總類內離散矩陣
sw=s1+s2;
%投影方向
w=((sw^-1)*(m1-m2)')';
%判別函式以及閾值T
T=-0.5*(m1+m2)*inv(sw)*(m1-m2)';

kind1=0;
kind2=0;
newtype1=[];
newtype2=[];
for i=4001:5000
    x=type1(i,:)
    g=w*x'+T;
    if g>0
        newtype1=[newtype1;x];
        kind1=kind1+1;
    else
        newtype2=[newtype2;x];
    end
end
for i=4001:5000
    x=type2(i,:)
    g=w*x'+T;
    if g>0
        newtype1=[newtype1;x];        
    else
        newtype2=[newtype2;x];
        kind2=kind2+1;
    end
end
correct=(kind1+kind2)/2000;
fprintf('\n綜合正確率:%.2f%%\n\n',correct*100);

執行結果:綜合正確率:50.85%。

 理論知識參考許國根的《模式識別與智慧計算的MATLAB實現》。可能程式碼不夠完美,歡迎大家積極探討,共同進步!