1. 程式人生 > >《李航:統計學習方法》--- K近鄰演算法(KNN)原理與簡單實現

《李航:統計學習方法》--- K近鄰演算法(KNN)原理與簡單實現

k近鄰演算法簡單,直觀:給定一個訓練資料集,對新的輸入例項,在訓練集中找到與該例項最鄰近的k個例項,這k個例項的多數屬於某個類,就把該輸入例項分為這個類。

K近鄰演算法模型

如上圖所示,藍色正方形表示一個類別,紅色三角形表示另一個類別,綠色圓圈表示待分類的樣本。按照KNN演算法,首先我們給k一個值,假設為5,那麼如圖所示,與綠色圓圈距離最近的5個樣本都在虛線圓之內,這五個樣本中數量最多的為藍色正方形所表示的類別,此時綠色圓圈的類別與藍色正方形相同。同理,假設k為3,此時實線圓之內數量最多的為紅色三角形,那麼綠色圓圈的類別就與紅色三角形的類別相同。

K近鄰模型由三個基本要素組成:距離度量,k值選擇,分類決策規則

距離度量

k近鄰模型的特徵空間一般是n維實數向量空間Rn。使用的距離是歐氏距離,但也可以是其它距離。設特徵空間Xn維實數向量空間Rnxi,xjX,xi=(x(1)i,x(2)i,...x(n)i),xj=(x(1)j,x(2)j,...x(n)j)xi,xjLp距離定義為

Lp(xi,xj)=(l=1n|x(l)ix(l)j|p)1p 這裡p1。當p=2時稱為歐氏距離,即L2(xi,xj)=(l=1n|x(l)ix(l)j|2)12p=1時,稱為曼哈頓距離,即L1(xi,xj)=l=1n|x(l)ix(l)j|p=時,它是各個座標距離的最大值.

k值選擇

k值得選擇會對k近鄰演算法的結果產生重大影響。
如果選擇的k值較小,就相當於用較小的的鄰域中的訓練例項進行預測。此時預測的結果會對近鄰的例項點非常敏感。
如果選擇較大的k值,就相當於在較大的鄰域中訓練例項進行預測。此時,與輸入例項較遠的訓練例項也會對預測起作用,使預測發生錯誤。
如果k等於訓練樣本個數,此時將輸入例項簡單的預測為訓練樣本中最多的類。這時模型過於簡單,會完全忽略訓練樣本中的大量有用資訊,是不可取的。
在應用中,k值一般選取一個比較小的數值,通常採用交叉驗證法來選取最優的k值。

分類決策規則

k近鄰演算法中分類決策規則往往是多數表決,即由輸入例項的k個鄰近的訓練例項中的多數類決定輸入例項的類。

程式碼實現

接下來,實現一個簡單的KNN演算法,資料集選用鳶尾花資料集,共150個樣本資料,選取10個作為測試集,剩下的140個作為訓練樣本。每個樣本包含5列,前4列為特徵,最後一列為樣本真實分類。

演算法步驟:

step.1—輸入要分類的資料和樣本集
step.2—計算新資料和每個訓練樣本的距離
step.3—找到與新資料最相近的K個樣本
step.4—統計K個樣本中每個類標號出現的次數
step.5—選擇出現頻率最大的類標號作為新資料的類別

C++實現

#include <iostream>
#include <fstream>
#include <sstream>
#include <cmath>

using namespace std;

const int data_dim = 4;         //資料維數
const int data_train_num = 140;   //訓練樣本資料個數
const int data_test_num = 10;   //測試樣本資料個數
const int k = 10;
const int cnum = 3;       //類別個數
const string filename_train = "iris_train.txt";    //樣本訓練資料集檔名稱
const string filename_test = "iris_test.txt";    //樣本測試資料集檔名稱


//讀取檔案
void readFile(const string filename, double datas[][data_dim], int *labels)
{
    ifstream  infile(filename);
    if(!infile)
    {
        cout << "無法開啟檔案" << endl;
        return ;
    }
    for(int i = 0; i < data_train_num;  i++)
    {
        string  str;
        getline(infile, str);
        istringstream istr(str);
        for(int j = 0;  j < data_dim;  j++)
            istr >> datas[i][j];
        istr >> labels[i];
    }
}

//計算特徵向量的歐氏距離
double calculate_distance(double *train, double *test)
{
    double result = 0;
    for(int i = 0; i < data_dim; i++)
    {
        result += pow(*train - *test, 2);
        train++;
        test++;
    }
    return pow(result, 0.5);
}

//k近鄰演算法
void knn(double d_train[][data_dim], int *l_train, double d_test[][data_dim], int *l_prediction)
{
    for(int i = 0; i < data_test_num; i++)
    {
        //用來儲存測試樣本與訓練樣本之間的歐式距離
        double eucl_distance[data_train_num]; 
        //標記是否為輸入例項的k個鄰近的訓練樣本
        bool flag[data_train_num] = {0};
        //k個鄰近訓練樣本中每個類別包含的樣本個數
        int c[cnum] = {0};
        for(int j = 0; j < data_train_num; j++)
        {
            eucl_distance[j] = calculate_distance(d_train[j], d_test[i]);
        }

        for(int m = 0; m < k; m++)
        {
            double max_distance = 1.79769e+308;
            int subs = 0;
            for(int j = 0; j < data_train_num; j++)
            {
                if(max_distance > eucl_distance[j] && flag[j] == 0)
                {
                    max_distance = eucl_distance[j];
                    subs = j;
                }
            }
            c[l_train[subs] -1]++;
            flag[subs] = 1;
        }
        int min = 0;
        int subs = 0;
        for(int j = 0; j < cnum; j++)
        {
            if(min < c[j])
            {
                min = c[j];
                subs = j;
            }
        }
        //+1是因為類別從1開始,而下標從0開始
        l_prediction[i] = subs + 1;
    }
}

int main()
{
    double      datas_train[data_train_num][data_dim] = {0};
    int         labels_train[data_train_num] = {0};
    double      datas_test[data_test_num][data_dim] = {0};
    int         labels_test[data_test_num] = {0};
    int            labes_prediction[data_test_num] = {0};

    readFile(filename_train, datas_train, labels_train);
    readFile(filename_test, datas_test, labels_test);

    knn(datas_train, labels_train, datas_test, labes_prediction);

    cout << "-----測試樣本真實分類-----" << endl;
    for(int i = 0; i < data_test_num; i++)
    {
        cout << labels_test[i] << "    ";
    }
    cout << endl << "-----測試樣本預測分類-----" << endl;
    for(int i = 0; i < data_test_num; i++)
    {
        cout << labes_prediction[i] << "    ";
    }
    cout << endl;
    system("pause");
    return 0;
}

程式碼執行結果

這裡寫圖片描述