統計學習方法c++實現之二 k近鄰法
統計學習方法c++實現之二 k近鄰演算法
前言
k近鄰演算法可以說概念上很簡單,即:“給定一個訓練資料集,對新的輸入例項,在訓練資料集中找到與這個例項最鄰近的k個例項,這k個例項的多數屬於某個類,就把該輸入分為這個類。”其中我認為距離度量最關鍵,但是距離度量的方法也很簡單,最長用的就是歐氏距離,其他的距離度量準則實際上就是不同的向量範數,這部分我就不贅述了,畢竟這系列部落格的重點是實現。程式碼地址:https://github.com/bBobxx/statistical-learning
kd樹
k近鄰演算法的思想很簡單,然而,再簡單的概念如果碰上高維度加上海量資料,就變得很麻煩,如果按照常規思想,將每個測試樣本和訓練樣本的距離算出來,在進行排序查詢,無疑效率十分低下,這也就是為什麼要介紹kd樹的原因。kd樹是一種二叉樹,kd樹的每個結點對應一個k維超矩形區域。
kd樹構建程式碼
每一次分割都需要確定一個軸和一個值,然後分割時只看該軸的資料,小於等於分割值就放到該結點的左子樹裡,大於分割值就放到右子樹中。那麼每個結點裡面需要儲存哪些內容呢?
我的實現裡面,每個結點有如下內容:
struct KdtreeNode { vector<double> val;//n維特徵 int cls;//類別 unsigned long axis;//分割軸 double splitVal;//分割的值 vector<vector<double>> leftTreeVal;//左子樹的值集合 vector<vector<double>> rightTreeVal;//右子樹的值集合 KdtreeNode* parent;//父節點 KdtreeNode* left;//左子節點 KdtreeNode* right;//右子節點 KdtreeNode(): cls(0), axis(0), splitVal(0.0), parent(nullptr), left(nullptr), right(nullptr){}; };
用kd樹實現的k近鄰演算法(還有其它的方法),訓練過程實際上就是樹的建造過程,我們用遞迴建立kd樹。
首先,我們需要建立並存儲根節點
KdtreeNode* root = new KdtreeNode();//類中用這個儲存根節點 void Knn::setRoot() {//這是建立根節點的程式,主要是設定左右子樹,還有分割軸,分割值 if(axisVec.empty()){ cout<<"please run createSplitAxis first."<<endl; throw axisVec.empty(); } auto axisv = axisVec; auto axis = axisv.top(); axisv.pop(); std::sort(trainData.begin(), trainData.end(), [&axis](vector<double> &left, vector<double > &right) { return left[axis]<right[axis]; }); unsigned long mid = trainData.size()/2; for(unsigned long i = 0; i < trainData.size(); ++i){ if(i!=mid){ if (i<mid) root->leftTreeVal.push_back(trainData[i]); else root->rightTreeVal.push_back(trainData[i]); } else{ root->val.assign(trainData[i].begin(),trainData[i].end()-1); root->splitVal = trainData[i][axis]; root->axis = axis; root->cls = *(trainData[i].end()-1); } } cout<<"root node set over"<<endl; }
上面的程式建立了根節點,但是分割軸是怎麼確定?當然可以依次選軸作為分割軸,但是這裡我們選擇按方差從大到小的順序選軸
stack<unsigned long> axisVec;//用棧儲存分割軸,棧頂軸方差最大。
void Knn::createSplitAxis(){//axisVec建立程式碼
cout<<"createSplitAxis..."<<endl;
//the last element of trainData is gt
vector<pair<unsigned long, double>> varianceVec;
auto sumv = trainData[0];
for(unsigned long i=1;i<trainData.size();++i){
sumv = sumv + trainData[i];
}
auto meanv = sumv/trainData.size();
vector<decltype(trainData[0]-meanv)> subMean;
for(const auto& c:trainData)
subMean.push_back(c-meanv);
for (unsigned long i = 0; i < trainData.size(); ++i) {
for (unsigned long j = 0; j < indim; ++j) {
subMean[i][j] *= subMean[i][j];
}
}
auto varc = subMean[0];
for(unsigned long i=1;i<subMean.size();++i){
varc = varc + subMean[i];
}
auto var = varc/subMean.size();
for(unsigned long i=0;i<var.size()-1;++i){//here not contain the axis of gt
varianceVec.push_back(pair<unsigned long, double>(i, var[i]));
}
std::sort(varianceVec.begin(), varianceVec.end(), [](pair<unsigned long, double> &left, pair<unsigned long, double> &right) {
return left.second < right.second;
});
for(const auto& variance:varianceVec){
axisVec.push(variance.first);//the maximum variance is on the top
}
cout<<"createSplitAxis over"<<endl;
}
現在要給根節點新增左右子樹:
root->left = buildTree(root, root->leftTreeVal, axisVec);
root->right = buildTree(root, root->rightTreeVal, axisVec);
來看一下buildTree程式碼:
KdtreeNode* Knn::buildTree(KdtreeNode*root, vector<vector<double>>& data, stack<unsigned long>& axisStack) {//第一個引數是父節點,第二個引數是目前沒有被分割的資料集合,第三個引數是當前的軸棧,
//由於後面要保證左右子樹的分割用的同一個軸,所以這裡要傳入。
stack<unsigned long> aS;
if(axisStack.empty())
aS=axisVec;
else
aS=axisStack;
auto node = new KdtreeNode();
node->parent = root;
auto axis2 = aS.top();
aS.pop();
std::sort(data.begin(), data.end(), [&axis2](vector<double> &left, vector<double > &right) {
return left[axis2]<right[axis2];
});//這裡用的c++11裡面的lambda函式
unsigned long mid = data.size()/2;
if(node->leftTreeVal.empty()&&node->rightTreeVal.empty()){
for(unsigned long i = 0; i < data.size(); ++i){
if(i!=mid){
if (i<mid)
node->leftTreeVal.push_back(data[i]);
else
node->rightTreeVal.push_back(data[i]);
} else{
node->val.assign(data[i].begin(),data[i].end()-1);
node->splitVal = data[i][axis2];
node->axis = axis2;
node->cls = *(data[i].end()-1);
}
}
}
if(!node->leftTreeVal.empty()){
node->left = buildTree(node, node->leftTreeVal, aS);//遞迴建立子樹
}
if(!node->rightTreeVal.empty()){
node->right = buildTree(node, node->rightTreeVal, aS);
}
return node;
}
建立好子樹後可以通過showTree函式前序遍歷樹來檢視,這裡就不演示了,程式碼中有這一步。
查詢K近鄰
對於用kd樹實現的Knn演算法來說,預測的過程就是查詢的過程,這裡我們給出查詢K個最近鄰的程式碼,中間用到了STL標準模板庫的priority_queue和pair的組合,用priority_queue實現大頂堆,對於由pair構成的priority_queue來說,預設的比較值是first,也就是說裡面的元素會根據pair的第一個元素從大到小排序,即用.top()得到的是最大值(預設比較函式的情況下)。在搜尋 K-近鄰時,我們可以設定一個最多有 k
個元素的大頂堆,這樣,在搜尋時,當堆滿時,只需比較當前搜尋點的 dis
是否小於堆頂點的 dis
,如果小於,堆頂出堆,並將當前搜尋點壓入,反之,則不變;當堆未滿時,直接將該搜尋點壓入。
priority_queue<pair<double, KdtreeNode*>> maxHeap;
下面給出查詢程式碼
void Knn::findKNearest(vector<double>& testD){
cout<<"the test data is(the last is class) ";
for(const auto& c:testD)
cout<<c<<" ";
cout<<"\nsearching "<<K<<" nearest val..."<<endl;
stack<KdtreeNode*> path;//這是查詢路徑
auto curNode = root;
while(curNode!= nullptr){//這個迴圈是為了初始化查詢路徑
path.push(curNode);
if(testD[curNode->axis]<=curNode->splitVal)
curNode = curNode->left;
else
curNode = curNode->right;
}
while(!path.empty()){
auto curN = path.top();
path.pop();
vector<double> testDF(testD.begin(),testD.end()-1);
double dis=0.0;
dis = computeDis(testDF, curN->val);
if(maxHeap.size()<K){
maxHeap.push(pair<double, KdtreeNode*>(dis, curN));
}
else{
if(dis<maxHeap.top().first){
maxHeap.pop();
maxHeap.push(pair<double, KdtreeNode*>(dis, curN));
}
}
if(path.empty())
continue;
auto curNparent = path.top();
KdtreeNode* curNchild;
if(testDF[curNparent->axis]<=curNparent->splitVal)//從這裡開始是為了查詢同一個父節點的
//另一個子樹中是否有比當前K個最近鄰更近的結點
curNchild = curNparent->right;//這裡和上面相反,剛好是另一個子樹。
else
curNchild = curNparent->left;
if(curNchild == nullptr)
continue;
double childDis = computeDis(testDF, curNchild->val);
if(childDis<maxHeap.top().first){//比較另一個子樹的根節點是不是比當前k個結點距離查詢點更近,
//如果是,將對應的子樹加入搜尋路徑
maxHeap.pop();
maxHeap.push(pair<double, KdtreeNode*>(childDis, curNchild));
while(curNchild!= nullptr){//add subtree to path
path.push(curNchild);
if(testD[curNchild->axis]<=curNchild->splitVal)
curNchild = curNchild->left;
else
curNchild = curNchild->right;
}
}
}
}
double Knn::computeDis(const vector<double>& v1, const vector<double>& v2){
auto v = v1 - v2;
double di = v*v;//這裡用到了基類中的操作符過載
return di;
}
總結
k近鄰演算法雖然概念簡單,但是實現由於要用到樹結構,編寫起來還是聽費事的,這裡還有更好的實現,演算法方面這篇講解的也很詳細我的查詢程式碼也是通過參考這篇寫的,感謝他們無私的貢獻。