1. 程式人生 > >基於KNN分類演算法手寫數字識別的實現(二)——構建KD樹

基於KNN分類演算法手寫數字識別的實現(二)——構建KD樹

上一篇已經簡單粗暴的建立了一個KNN模型對手寫圖片進行了識別,所以本篇文章採用構造KD樹的方法實現手寫數字的識別。

(一)構造KD樹

構造KD樹的基本原理網上都有介紹,所以廢話不多說,直接上程式碼。

#Knn KD_Tree演算法

import math
from collections import namedtuple

#定義命名元祖,用來存放結果,最近點,最近距離和訪問過的節點數
result = namedtuple('Result_tuple', 'nearest_point nearest_dist nodes_visited')


# In[5]:


#構造KD樹

#初始化構造KD樹的元素
class KD_Node(object):
    
    def __init__(self, dom_elt, split, left, right):
        
        self.dom_elt = dom_elt #k維向量節點
        self.split = split     #整數,進行分割的序號
        self.left = left       #該節點分割超平面的左子樹
        self.right = right     #該節點分割超平面的右子樹
        
class KD_Tree(object):
    
    def __init__(self, data):
        
        k = len(data[0])       #資料的維度
        
        def Create_Node(split, data_set): #按第split維劃分資料data_set建立的KD_Node
            
            if (data_set == []):       #資料集為空

                return None
            #key引數的值為一個函式,此函式只有一個引數且返回一個值來進行比較
            #operator模組提供的itemgetter函式用來獲取物件有哪些維的資料,
            #引數為需要獲取的資料物件中的序號
            data_set = list(data_set)
            data_set.sort(key=lambda x: x[split])
            split_positon = len(data_set) // 2 #//代表整除
            median = data_set[split_positon] #中位數
            split_next = (split + 1) % k 
            #遞迴建立KD數
            return KD_Node(median, split,
                          Create_Node(split_next, data_set[:split_positon]),
                          Create_Node(split_next, data_set[split_positon + 1:]))
        
        self.root = Create_Node(0, data)
        
#KD樹的前序遍歷
def Pre_Order(root):
    
#     print(root.dom_elt)
    if (root.left):
        Pre_Order(root.left)
    if (root.right):
        Pre_Order(root.right)

KD樹構造完成後,可以計算最近鄰。

#搜尋最近鄰

def Find_Nearest(tree, point):
    
    k = len(point) #資料維度
    
    def Travel(kd_node, target, max_dist):
        
        if kd_node is None:
            
            return result([0] * k, float("inf"), 0)#inf表示正無窮,-inf表示負無窮
        
        nodes_visited = 1
        s = kd_node.split  #進行分割的維度
        pivot = kd_node.dom_elt #進行分割的軸
        
        if target[s] <= pivot[s]: #如果目標點第s維小於分割軸對應值,即目標離左子樹更近
            
            nearer_node = kd_node.left #下一個訪問的點為左子樹
            further_node = kd_node.right #同時記錄右子樹
        else:                     #目標離右子樹較近的時候
            
            nearer_node = kd_node.right #下一個訪問點為右子樹根節點
            further_node = kd_node.left #記錄左子樹
        
        temp1 = Travel(nearer_node, target, max_dist) #遍歷找到包含目標點的位置
        nearest = temp1.nearest_point #以此節點作為“當前最近點”
        dist = temp1.nearest_dist     #更新最近距離
        nodes_visited += temp1.nodes_visited
        
        if dist < max_dist:
            
            max_dist = dist #最近點將在以目標點為圓心,max_dist為半徑的超球體內
        
        temp_dist = abs(pivot[s] - target[s]) #第s維上目標點與分割超平面的距離
        
        if max_dist < temp_dist: #判斷超球體是否與分割平面相交
            
            return result(nearest, dist, nodes_visited)
            
        #計算目標點與分割點的歐氏距離
        temp_dist = math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))
        
        if temp_dist < dist: #如果得到更近的點
            
            nearest = pivot  #更新更近的點
            dist = temp_dist #更新最近的距離
            max_dist = dist  #更新超球體半徑
        
        #檢查另一個子節點對應的區域是否有更近的點
        temp2 = Travel(further_node, target, max_dist)
        nodes_visited += temp2.nodes_visited
        
        if temp2.nearest_dist <  dist: #如果另一個子節點中存在更近的距離
            
            nearest = temp2.nearest_point #更新最近的點
            dist = temp2.nearest_dist     #更新最近距離
        
        return result(nearest, dist, nodes_visited)
    
    return Travel(tree.root, point, float("inf")) #從根節點開始遞迴

測試結果,計算[2,4.5]離資料集:[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]中最近的點。

if __name__ == "__main__":
    
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    kd = KD_Tree(data)
    rst = Find_Nearest(kd, [2,4.5])

[2,4.5]最近鄰為[2,3],最短距離為1.5.測試結果看出KD樹的效果還是不錯的。那麼在大資料高維度情況下,KD樹的測試結果怎樣呢。

(二)對比蠻力實現和KD樹實現的區別

對之前處理的1萬條樣本資料選擇8000條作為訓練集,2000條作為檢驗集。

考慮到程式碼行較多的情況,本次對比使用封裝模組,然後呼叫模組執行測試結果。

生成3個.py檔案:Sample.py、Knn.py和KD_Tree.py

此部分程式碼與前面的程式碼區別不大,就不再進行復制。如有需要可以在網頁連結中下載,提取碼: po7s。

執行檔案為Main,py

import sys
sys.path.append(r"D:/Python_work/機器學習/KNN分類演算法/Knn")

from Sample import Sample_PC
from datetime import datetime


#呼叫引數
k = 3
train_file_route = r"E:/data/digit_data_copy/train/"
test_file_route = r"E:/data/digit_data_copy/test/"
model = "KD_Tree"


#執行蠻力實現
func1 = Sample_PC(3,train_file_route, test_file_route,None)
t1 = datetime.now()
result1 = func1.test_data()
t2 = datetime.now()
print('knn耗時:', t2-t1)


#執行KD樹實現
func2 = Sample_PC(3,train_file_route, test_file_route, model=model)
t3 = datetime.now()
result2 = func2.test_data()
t4 = datetime.now()
print('KD_Tree耗時:', t4-t3)

結論:

蠻力實現:準確率:0.977,耗時:2分56秒

混淆矩陣

file_name
forecast_data 0 1 2 3 4 5 6 7 8 9
real_data
0 209 1 0 0 0 0 1 1 0 0
1 0 221 0 0 0 0 0 0 0 0
2 2 0 163 0 1 0 0 2 0 0
3 0 0 0 206 0 1 0 2 1 0
4 0 1 1 0 209 1 1 0 0 2
5 0 1 0 1 0 172 3 0 0 1
6 0 1 0 0 0 0 184 0 0 0
7 0 4 0 0 0 0 0 203 0 0
8 1 1 1 1 0 2 0 1 198 0
9 1 2 0 2 1 0 0 4 0 189

KD樹實現:準確率:0.989,耗時:1個小時53分鐘

混淆矩陣:

雖然,KD樹的準確率在蠻力實現之上,但KD樹對於高維大資料的計算大過於耗費時間,且準確率提升也不是特別高。總體而言,knn分類效果較好,但計算比較耗時,這也是它最大的一個缺點。