1. 程式人生 > >簡單粗暴理解支援向量機(SVM)及其MATLAB例項

簡單粗暴理解支援向量機(SVM)及其MATLAB例項

目錄

SVM概述

QP求解

SVM概述

SVM已經是非常流行、大家都有所耳聞的技術了。網路上也有很多相關的部落格,講解得都非常詳細。如果你要從零開始推導一個SVM,細緻摳它全程的數學原理,我建議可以閱讀此篇文章:Zhang Hao的《從零構建支援向量機》。因此本文就不做過多的枯燥的數學原理的講解。

下面我們只針對數學基礎不一、偏工程應用的同學,用簡單的語言描述來幫助理解SVM。

傳統的SVM做的事情其實就是找到一個超平面,實現二分類,一類+1,一類-1。如上所示。它的目的就是使得兩類的間隔最大。黑色的塊表示距離分割面最近的樣本向量,稱為支援向量

如果我們在低維空間裡找不到一個線性分類面把樣本分開,SVM就為我們提供了一個思路:將資料從低維空間對映到高維空間後,就很可能使得這堆資料線性可分。比如說,我們要在貓科動物這個特徵很侷限的“低維空間”裡去分貓和老虎,是比較困難的,因為他們很多特徵比較相近。但是,如果我們有了更多的參考依據,從生物界的視角,即一個“高維空間”再去區分貓和老虎,我們就有了更多的理由來做出科學的辨別。至於如何低維對映到高維,就是一門數學上的學問了。

資料從輸入到輸出的過程其實和神經網路非常像:

K就是核函式,做一個內積的運算。SVM中核函式保證了低維空間裡的計算量,輸出到高維空間裡。

K相當於隱含層的神經元。核函式的輸出乘上權重,進入啟用函式處。

SVM的改進:解決迴歸擬合問題的SVR

  • 為了利用SVM解決迴歸擬合方面的問題,Vapnik等人在SVM分 類的基礎上引入了 不敏感損失函式,從而得到了迴歸型支援向 量機(Support Vector Machine for Regression,SVR)。

  • SVM應用於迴歸擬合分析時,其基本思想不再是尋找一個最優 分類面使得兩類樣本分開,而是尋找一個最優分類面使得所有 訓練樣本離該最優分類面的誤差最小

多分類的SVM

當我們要分多類,而不是簡單的二分類(+1,-1)時,怎麼破?

解決思路:把多分類轉化為二分類問題。具體來看有兩個辦法:

1. one-against-all

      Classification of new instances for one-against-all case is done by a winner-takes-all strategy, in which the classifier

with the highest output function assigns the class.

比如有一堆樣本,打算分成10類。那麼我們先取第1類訓練標記為【1】。其他9類都是【-1】。這樣經過一次SVM就可以得到第1類。

然後我們對【-1】中的9類繼續做上述操作,分出第2類。

再以此類對,逐漸把第3、第4類分出來……直至分完。

2. one-against-one

       For the one-against-one approach, classification is done by a max-wins voting strategy, in which every classifier assigns the instance to one of the two classes, then the vote for the assigned class is increased by one vote, and finally the class with most votes determines the instance classification.

比如,一共有10種類別的一堆資料。那麼我們就要訓練C{2,5}=10(組合數)個SVM分類器。每個SVM分類器都可以區分出兩種類別。我們把資料分別輸入到這10個SVM分類器中,根據結果進行投票,依據得票數最多來確定它的類別。

QP求解

大致有下面4種方法:

分塊演算法(Chunking)

Osuna演算法

序列最小優化演算法(Sequential Minimal Optimization,SMO)

增量學習演算法(IncrementalLearning)

數學原理比較難解釋清楚,大家可以看Zhang Hao的那篇文章細究。

SVM的MATLAB實現:Libsvm

重要函式:

  • meshgrid 交叉驗證用

    • –  Generate X and Y arrays for 3-D plots

    • –  [X,Y] = meshgrid(x,y) –

  • svmtrain

    • –  Train support vector machine classifier

    • –  model = svmtrain(train_label,train_matrix,’libsvm_options’);

    • Options:可用的選項即表示的涵義如下:
        -s svm型別:SVM設定型別(預設0)
          0 -- C-SVC
          1 -- nu-SVC
          2 -- one-class SVM
          3 -- epsilon-SVR
          4 -- nu-SVR
        -t 核函式型別:核函式設定型別(預設2)
          0 -- linear: u'*v 線性
          1 -- polynomial: (gamma*u'*v + coef0)^degree   多項式
          2 -- radial basis function: exp(-gamma*|u-v|^2) RBF
          3 -- sigmoid: tanh(gamma*u'*v + coef0)
          4 -- precomputed kernel (kernel values in training_instance_matrix)
        -d degree:核函式中的degree設定(針對多項式核函式)(預設3)
        -g r(gama):核函式中的gamma函式設定(針對多項式/rbf/sigmoid核函式)(預設1/ k)
        -r coef0:核函式中的coef0設定(針對多項式/sigmoid核函式)((預設0)
        -c cost:設定C-SVC,e -SVR和v-SVR的引數(損失函式)(預設1)  懲罰因子
        -n nu:設定v-SVC,一類SVM和v- SVR的引數(預設0.5)
        -p p:設定e -SVR 中損失函式p的值(預設0.1)
        -m cachesize:設定cache記憶體大小,以MB為單位(預設40)
        -e eps:設定允許的終止判據(預設0.001)
        -h shrinking:是否使用啟發式,0或1(預設1)
        -wi weight:設定第幾類的引數C為weight*C(C-SVC中的C)(預設1)
        -v n: n-fold互動檢驗模式,n為fold的個數,必須大於等於2
        其中-g選項中的k是指輸入資料中的屬性數。option -v 隨機地將資料剖分為n部分並計算互動檢驗準確度和均方根誤差。
         以上這些引數設定可以按照SVM的型別和核函式所支援的引數進行任意組合,如果設定的引數在函式或SVM型別中沒有也不會產生影響,程式不會接受該引數;如果應有的引數設定不正確,引數將採用預設值。

  • svmpredict

    • –  Predict data using support vector machine

    • –  [predict_label,accuracy] = svmpredict(test_label,test_matrix,model);

【例項】用SVM分類

%% I. 清空環境變數
clear all
clc

%% II. 匯入資料
load BreastTissue_data.mat

%%
% 1. 隨機產生訓練集和測試集
n = randperm(size(matrix,1));

%%
% 2. 訓練集――80個樣本
train_matrix = matrix(n(1:80),:);
train_label = label(n(1:80),:);

%%
% 3. 測試集――26個樣本
test_matrix = matrix(n(81:end),:);
test_label = label(n(81:end),:);

%% III. 資料歸一化
[Train_matrix,PS] = mapminmax(train_matrix');
Train_matrix = Train_matrix';
Test_matrix = mapminmax('apply',test_matrix',PS);
Test_matrix = Test_matrix';

%% IV. SVM建立/訓練(RBF核函式)
%%
% 1. 尋找最佳c/g引數――交叉驗證方法
[c,g] = meshgrid(-10:0.2:10,-10:0.2:10);
[m,n] = size(c);
cg = zeros(m,n);
eps = 10^(-4);
v = 5;
bestc = 1;
bestg = 0.1;
bestacc = 0;
for i = 1:m
    for j = 1:n
        cmd = ['-v ',num2str(v),' -t 2',' -c ',num2str(2^c(i,j)),' -g ',num2str(2^g(i,j))];
        cg(i,j) = svmtrain(train_label,Train_matrix,cmd);     
        if cg(i,j) > bestacc
            bestacc = cg(i,j);
            bestc = 2^c(i,j);
            bestg = 2^g(i,j);
        end        
        if abs( cg(i,j)-bestacc )<=eps && bestc > 2^c(i,j) 
            bestacc = cg(i,j);
            bestc = 2^c(i,j);
            bestg = 2^g(i,j);
        end               
    end
end
cmd = [' -t 2',' -c ',num2str(bestc),' -g ',num2str(bestg)];

%%
% 2. 建立/訓練SVM模型
model = svmtrain(train_label,Train_matrix,cmd);

%% V. SVM模擬測試
[predict_label_1,accuracy_1] = svmpredict(train_label,Train_matrix,model);
[predict_label_2,accuracy_2] = svmpredict(test_label,Test_matrix,model);
result_1 = [train_label predict_label_1];
result_2 = [test_label predict_label_2];

%% VI. 繪圖
figure
plot(1:length(test_label),test_label,'r-*')
hold on
plot(1:length(test_label),predict_label_2,'b:o')
grid on
legend('真實類別','預測類別')
xlabel('測試集樣本編號')
ylabel('測試集樣本類別')
string = {'測試集SVM預測結果對比(RBF核函式)';
          ['accuracy = ' num2str(accuracy_2(1)) '%']};
title(string)

【例項】用SVM迴歸

%% I. 清空環境變數
clear all
clc

%% II. 匯入資料
load concrete_data.mat

%%
% 1. 隨機產生訓練集和測試集
n = randperm(size(attributes,2));

%%
% 2. 訓練集――80個樣本
p_train = attributes(:,n(1:80))';
t_train = strength(:,n(1:80))';

%%
% 3. 測試集――23個樣本
p_test = attributes(:,n(81:end))';
t_test = strength(:,n(81:end))';

%% III. 資料歸一化
%%
% 1. 訓練集
[pn_train,inputps] = mapminmax(p_train');
pn_train = pn_train';
pn_test = mapminmax('apply',p_test',inputps);
pn_test = pn_test';

%%
% 2. 測試集
[tn_train,outputps] = mapminmax(t_train');
tn_train = tn_train';
tn_test = mapminmax('apply',t_test',outputps);
tn_test = tn_test';

%% IV. SVM模型建立/訓練
%%
% 1. 尋找最佳c引數/g引數
[c,g] = meshgrid(-10:0.5:10,-10:0.5:10);
[m,n] = size(c);
cg = zeros(m,n);
eps = 10^(-4);
v = 5;
bestc = 0;
bestg = 0;
error = Inf;
for i = 1:m
    for j = 1:n
        cmd = ['-v ',num2str(v),' -t 2',' -c ',num2str(2^c(i,j)),' -g ',num2str(2^g(i,j) ),' -s 3 -p 0.1'];
        cg(i,j) = svmtrain(tn_train,pn_train,cmd);
        if cg(i,j) < error
            error = cg(i,j);
            bestc = 2^c(i,j);
            bestg = 2^g(i,j);
        end
        if abs(cg(i,j) - error) <= eps && bestc > 2^c(i,j)
            error = cg(i,j);
            bestc = 2^c(i,j);
            bestg = 2^g(i,j);
        end
    end
end

%%
% 2. 建立/訓練SVM  
cmd = [' -t 2',' -c ',num2str(bestc),' -g ',num2str(bestg),' -s 3 -p 0.01'];
model = svmtrain(tn_train,pn_train,cmd);

%% V. SVM模擬預測
[Predict_1,error_1] = svmpredict(tn_train,pn_train,model);
[Predict_2,error_2] = svmpredict(tn_test,pn_test,model);

%%
% 1. 反歸一化
predict_1 = mapminmax('reverse',Predict_1,outputps);
predict_2 = mapminmax('reverse',Predict_2,outputps);

%%
% 2. 結果對比
result_1 = [t_train predict_1];
result_2 = [t_test predict_2];

%% VI. 繪圖
figure(1)
plot(1:length(t_train),t_train,'r-*',1:length(t_train),predict_1,'b:o')
grid on
legend('真實值','預測值')
xlabel('樣本編號')
ylabel('耐壓強度')
string_1 = {'訓練集預測結果對比';
           ['mse = ' num2str(error_1(2)) ' R^2 = ' num2str(error_1(3))]};
title(string_1)
figure(2)
plot(1:length(t_test),t_test,'r-*',1:length(t_test),predict_2,'b:o')
grid on
legend('真實值','預測值')
xlabel('樣本編號')
ylabel('耐壓強度')
string_2 = {'測試集預測結果對比';
           ['mse = ' num2str(error_2(2)) ' R^2 = ' num2str(error_2(3))]};
title(string_2)

%% VII. BP神經網路
%%
% 1. 資料轉置
pn_train = pn_train';
tn_train = tn_train';
pn_test = pn_test';
tn_test = tn_test';

%%
% 2. 建立BP神經網路
net = newff(pn_train,tn_train,10);

%%
% 3. 設定訓練引數
net.trainParam.epochs = 1000;
net.trainParam.goal = 1e-3;
net.trainParam.show = 10;
net.trainParam.lr = 0.1;

%%
% 4. 訓練網路
net = train(net,pn_train,tn_train);

%%
% 5. 模擬測試
tn_sim = sim(net,pn_test);

%%
% 6. 均方誤差
E = mse(tn_sim - tn_test);

%%
% 7. 決定係數
N = size(t_test,1);
R2=(N*sum(tn_sim.*tn_test)-sum(tn_sim)*sum(tn_test))^2/((N*sum((tn_sim).^2)-(sum(tn_sim))^2)*(N*sum((tn_test).^2)-(sum(tn_test))^2)); 

%%
% 8. 反歸一化
t_sim = mapminmax('reverse',tn_sim,outputps);

%%
% 9. 繪圖
figure(3)
plot(1:length(t_test),t_test,'r-*',1:length(t_test),t_sim,'b:o')
grid on
legend('真實值','預測值')
xlabel('樣本編號')
ylabel('耐壓強度')
string_3 = {'測試集預測結果對比(BP神經網路)';
           ['mse = ' num2str(E) ' R^2 = ' num2str(R2)]};
title(string_3)