1. 程式人生 > >機器學習:KNN演算法(MATLAB實現)

機器學習:KNN演算法(MATLAB實現)

   K-近鄰演算法的思想如下:首先,計算新樣本與訓練樣本之間的距離,找到距離最近的K 個鄰居;然後,根據這些鄰居所屬的類別來判定新樣本的類別,如果它們都屬於同一個類別,那麼新樣本也屬於這個類;否則,對每個後選類別進行評分,按照某種規則確定新樣本的類別。(統計出現的頻率)

該演算法比較適用於樣本容量比較大的類域的自動分類,而那些樣本容量較小的類域採用這種演算法比較容易產生誤分當K值較小時可能產生過擬合,因為訓練誤差很小,但是測試誤差可能很大;相反,當K值較大時可能產生欠擬合。

演算法虛擬碼

對未知類別屬性的資料集中的每個點依次執行以下操作:

(1)    計算已知類別的資料集中的點與當前點之間的距離;

(2)    按照距離遞增次序排序;

(3)    選取與當前點距離最小的K個點;

(4)    確定前K個點所在類別的出現頻率;

(5)    返回前K個點出現頻率最高的類別作為當前點的預測分類。

  1. %  
  2. %手寫數字識別系統的測試程式碼  
  3. %  
  4. function handWritingTest()  
  5.     tic; %開始計時  
  6.     K = 3;  % 這裡可以調整k值  
  7.     trainLabels = [];  
  8.     direct = mfilename('fullpath');%  
  9.     traindirect = strrep(direct,'handWritingTest','trainingDigits'); %trainingDigits  
  10.    %獲得路徑  
  11.     traindirfile = dir(fullfile(traindirect,'*.txt'));%提取字尾名.txt  
  12.     traindircell = struct2cell(traindirfile)';  
  13.     trainfilenames = traindircell(:,1);  
  14.     trainfileNums = length(trainfilenames);  
  15.     trainMat = zeros(trainfileNums,1024);  
  16.     for i = 1:trainfileNums  
  17.         fileNameStr = trainfilenames(i);  
  18.         str = deblank(fileNameStr);  
  19.         s = regexp(str,'\.','split'); %  
  20.         fileStr = s{1}(1);  
  21.         classNumStr =  regexp(fileStr,'\_','split');  
  22.         trainLabels(i)=str2num(char(classNumStr{1}(1))); %得到類別 0 - 9   
  23.         filePath = strcat(traindirect,'\',fileNameStr); %檔案路徑  
  24.         trainMat(i,:) = img2vector(filePath);%處理檔案 獲得向量  
  25.     end  
  26.     %測試樣本  
  27.     direct = mfilename('fullpath');  
  28.     testdirect = strrep(direct,'handWritingTest','testDigits');%testDigits  
  29.     testdirfile = dir(fullfile(testdirect,'*.txt'));  
  30.     testdircell = struct2cell(testdirfile)';  
  31.     testfilenames = testdircell(:,1);  
  32.     testfileNums = length(testfilenames);  
  33.     errorcount = 0;  
  34.     for j = 1:testfileNums  
  35.         fileNameStr = testfilenames(j);  
  36.         str = deblank(fileNameStr);  
  37.         s = regexp(str,'\.','split');  
  38.         fileStr = s{1}(1);  
  39.         classNumStr =  regexp(fileStr,'\_','split');  
  40.         testLabel = str2num(char(classNumStr{1}(1))); %得到類別 0 - 9   
  41.         filePath = strcat(testdirect,'\',fileNameStr);  
  42.         testVector = img2vector(filePath);  
  43.         classifyRet = classify(testVector,trainMat,trainLabels,K);  
  44.         if(classifyRet ~= testLabel)  
  45.             errorcount = errorcount + 1;  
  46.             fprintf('test result:  %d,    real result:  %d ,    here error!!! \n',classifyRet,testLabel);  
  47.         else  
  48.             fprintf('test result:  %d,    real result:  %d \n',classifyRet,testLabel);  
  49.         end  
  50.     end  
  51.     lastTime = num2str(toc);  
  52.     fprintf('\n the sum numbers of errors :  %d ',errorcount);  
  53.     fprintf('\n the total error rate :  %f  ' ,(errorcount / testfileNums));  
  54.     fprintf('\n total time :    %f',lastTime);  
  55. end  
  56. %  
  57. %KNN演算法 classify(test,dataSet,labels,k)  
  58. %四個引數:test用於分類的輸入向量;輸入的訓練樣本集為dataSet;  
  59. %標籤向量為labels; k 表示用於選擇最近鄰居的數目;  
  60. %  
  61. function maxClass = classify(test,dataSet,labels,k)  
  62.     [dataRow,dataCol] = size(dataSet);%dataRow:樣本個數;dataCol:特徵  
  63.     %求距離 test 與樣本資料之間的距離   這裡為歐式距離  
  64.     diffMat = dataSet;  
  65.     for i = 1:dataRow  
  66.         diffMat(i,:) = diffMat(i,:) - test;   
  67.     end  
  68.     sqdiffMat = diffMat.^2;  
  69.     sqDistances = sum(sqdiffMat,2).^(0.5);  
  70.     [p,q] = sort(sqDistances);  %p代表要排序的數,q代表要排序的數原來對應的索引  
  71.     %通過k  來求最鄰居的前k 個數據,然後找的在這些資料中類別最多的  
  72.     classCount=zeros(10,1);  
  73.     class = [];  
  74.     for j = 1:k  
  75.         tempLabel = labels(q(j));  
  76.         class(j) = tempLabel;%沒用到  
  77.         classCount(tempLabel+1) = classCount(tempLabel+1)+1;  
  78.     end  
  79.     [r,s] = max(classCount);  
  80.     maxClass = s - 1;  %返回 相似個數最多的 那個類  
  81. end  
  82. %  
  83. %將32*32的二進位制圖形矩陣轉換為1*1024的向量  
  84. %  
  85. function retVector = img2vector(fileName)  
  86.    fileName = char(fileName);  
  87.    tempVector = [];  
  88.     % 讀檔案  
  89.    fileData = textread(fileName,'%s');  
  90.    fileData = char(fileData);%讀取檔案,並將檔案轉換矩陣的格式  
  91.    temp = fileData(:)';  
  92.    for i = 1 : length(temp)  
  93.        tempVector(i) = str2num(temp(i));  
  94.    end  
  95.    retVector = tempVector;  
  96. end