KNN(K-Nearest Neighbor)演算法Matlab實現
KNN(K-Nearest Neighbor)演算法即K最鄰近演算法,是實現分類器中比較簡單易懂的一種分類演算法。K臨近之所以簡單是因為它比較符合人們直觀感受,即人們在觀察事物,對事物進行分類的時候,人們最容易想到的就是誰離那一類最近誰就屬於哪一類,即俗話常說的“近朱者赤,近墨者黑”,人們自然而然地把這種觀察方式延伸到資料分類處理領域。K-NN演算法就是基於歐幾里得距離推斷事物類別的一種實現方法。
KNN演算法中K指的是尋找距離測試樣本最近的訓練集元素個數。
一、KNN演算法描述如下:
1、初始化訓練集和類別;
2、計算測試集樣本與訓練集樣本的歐氏距離;
3、根據歐氏距離大小對訓練集樣本進行升序排序;
4、選取歐式距離最小的前K個訓練樣本,統計其在各類別中的頻率;
5、返回頻率最大的類別,即測試集樣本屬於該類別。
二、Matlab程式碼實現
在理解了演算法步驟之後,就能很容易地實現演算法,Matlab作為一種研究資料的優秀工具,是實現和理解這個演算法不錯的選擇。
三、測試結果%實現KNN演算法 %%演算法描述 %1、初始化訓練集和類別; %2、計算測試集樣本與訓練集樣本的歐氏距離; %3、根據歐氏距離大小對訓練集樣本進行升序排序; %4、選取歐式距離最小的前K個訓練樣本,統計其在各類別中的頻率; %5、返回頻率最大的類別,即測試集樣本屬於該類別。 close all; clc; %%演算法實現 %step1、初始化訓練集、測試集、K值 %建立一個三維矩陣,二維表示同一類下的二維座標點,第三維表示類別 trainData1=[0 0;0.1 0.3;0.2 0.1;0.2 0.2];%第一類訓練資料 trainData2=[1 0;1.1 0.3;1.2 0.1;1.2 0.2];%第二類訓練資料 trainData3=[0 1;0.1 1.3;0.2 1.1;0.2 1.2];%第三類訓練資料 trainData(:,:,1)=trainData1;%設定第一類測試資料 trainData(:,:,2)=trainData2;%設定第二類測試資料 trainData(:,:,3)=trainData3;%設定第三類測試資料 trainDim=size(trainData);%獲取訓練集的維數 testData=[1.6 0.3];%設定1個測試點 K=7; %%分別計算測試集中各個點與每個訓練集中的點的歐氏距離 %把測試點擴充套件成矩陣 testData_rep=repmat(testData,4,1); %設定三個二維矩陣存放測試集與測試點的擴充套件矩陣的差值平方 %diff1=zero(trainDim(1),trianDim(2)); %diff2=zero(trainDim(1),trianDim(2)); %diff3=zero(trainDim(1),trianDim(2)); for i=1:trainDim(3) diff1=(trainData(:,:,1)-testData_rep).^2; diff2=(trainData(:,:,2)-testData_rep).^2; diff3=(trainData(:,:,3)-testData_rep).^2; end %設定三個一維陣列存放歐式距離 distance1=(diff1(:,1)+diff1(:,2)).^0.5; distance2=(diff2(:,1)+diff2(:,2)).^0.5; distance3=(diff3(:,1)+diff3(:,2)).^0.5; %將三個一維數組合成一個二維矩陣 temp=[distance1 distance2 distance3]; %將這個二維矩陣轉換為一維陣列 distance=reshape(temp,1,3*4); %對距離進行排序 distance_sort=sort(distance); %用一個迴圈尋找最小的K個距離裡面那個類裡出現的頻率最高,並返回該類 num1=0;%第一類出現的次數 num2=0;%第二類出現的次數 num3=0;%第三類出現的次數 sum=0;%sum1,sum2,sum3的和 for i=1:K for j=1:4 if distance1(j)==distance_sort(i) num1=num1+1; end if distance2(j)==distance_sort(i) num2=num2+1; end if distance3(j)==distance_sort(i) num3=num3+1; end end sum=num1+num2+num3; if sum>=K break; end end class=[num1 num2 num3]; classname=find(class(1,:)==max(class)); fprintf('測試點(%f %f)屬於第%d類',testData(1),testData(2),classname); %%使用繪圖將訓練集點和測試集點繪畫出來 figure(1); hold on; for i=1:4 plot(trainData1(i,1),trainData1(i,2),'*'); plot(trainData2(i,1),trainData2(i,2),'o'); plot(trainData3(i,1),trainData3(i,2),'>'); end plot(testData(1),testData(2),'x'); text(0.1,0.1,'第一類'); text(1.1,0.1,'第二類'); text(0.1,1,'第三類');
對於測試點testData=[1.6 0.3],測試結果如下:
圖中,星形表示第一類訓練樣本點,圈兒代表第二類訓練樣本點,三角形代表第三類訓練樣本點,x代表測試樣本點;
測試點(1.600000 0.300000)屬於第2類>>
將測試點改為:testData=[0.1 0.4],執行結果如下:
測試點(0.100000 0.400000)屬於第1類>>
將測試點改為:testData=[0.3 1.2],測試結果如下:
測試點(0.300000 1.200000)屬於第3類>>
經過測試,可以看出KNN能夠很好的實現樣本分類功能,並且理解簡單容易,對於初學分類器的開發者而言,KNN是一個很好的選擇!