基於KNN分類演算法手寫數字識別的實現(二)——構建KD樹
阿新 • • 發佈:2018-12-12
上一篇已經簡單粗暴的建立了一個KNN模型對手寫圖片進行了識別,所以本篇文章採用構造KD樹的方法實現手寫數字的識別。
(一)構造KD樹
構造KD樹的基本原理網上都有介紹,所以廢話不多說,直接上程式碼。
#Knn KD_Tree演算法 import math from collections import namedtuple #定義命名元祖,用來存放結果,最近點,最近距離和訪問過的節點數 result = namedtuple('Result_tuple', 'nearest_point nearest_dist nodes_visited') # In[5]: #構造KD樹 #初始化構造KD樹的元素 class KD_Node(object): def __init__(self, dom_elt, split, left, right): self.dom_elt = dom_elt #k維向量節點 self.split = split #整數,進行分割的序號 self.left = left #該節點分割超平面的左子樹 self.right = right #該節點分割超平面的右子樹 class KD_Tree(object): def __init__(self, data): k = len(data[0]) #資料的維度 def Create_Node(split, data_set): #按第split維劃分資料data_set建立的KD_Node if (data_set == []): #資料集為空 return None #key引數的值為一個函式,此函式只有一個引數且返回一個值來進行比較 #operator模組提供的itemgetter函式用來獲取物件有哪些維的資料, #引數為需要獲取的資料物件中的序號 data_set = list(data_set) data_set.sort(key=lambda x: x[split]) split_positon = len(data_set) // 2 #//代表整除 median = data_set[split_positon] #中位數 split_next = (split + 1) % k #遞迴建立KD數 return KD_Node(median, split, Create_Node(split_next, data_set[:split_positon]), Create_Node(split_next, data_set[split_positon + 1:])) self.root = Create_Node(0, data) #KD樹的前序遍歷 def Pre_Order(root): # print(root.dom_elt) if (root.left): Pre_Order(root.left) if (root.right): Pre_Order(root.right)
KD樹構造完成後,可以計算最近鄰。
#搜尋最近鄰 def Find_Nearest(tree, point): k = len(point) #資料維度 def Travel(kd_node, target, max_dist): if kd_node is None: return result([0] * k, float("inf"), 0)#inf表示正無窮,-inf表示負無窮 nodes_visited = 1 s = kd_node.split #進行分割的維度 pivot = kd_node.dom_elt #進行分割的軸 if target[s] <= pivot[s]: #如果目標點第s維小於分割軸對應值,即目標離左子樹更近 nearer_node = kd_node.left #下一個訪問的點為左子樹 further_node = kd_node.right #同時記錄右子樹 else: #目標離右子樹較近的時候 nearer_node = kd_node.right #下一個訪問點為右子樹根節點 further_node = kd_node.left #記錄左子樹 temp1 = Travel(nearer_node, target, max_dist) #遍歷找到包含目標點的位置 nearest = temp1.nearest_point #以此節點作為“當前最近點” dist = temp1.nearest_dist #更新最近距離 nodes_visited += temp1.nodes_visited if dist < max_dist: max_dist = dist #最近點將在以目標點為圓心,max_dist為半徑的超球體內 temp_dist = abs(pivot[s] - target[s]) #第s維上目標點與分割超平面的距離 if max_dist < temp_dist: #判斷超球體是否與分割平面相交 return result(nearest, dist, nodes_visited) #計算目標點與分割點的歐氏距離 temp_dist = math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target))) if temp_dist < dist: #如果得到更近的點 nearest = pivot #更新更近的點 dist = temp_dist #更新最近的距離 max_dist = dist #更新超球體半徑 #檢查另一個子節點對應的區域是否有更近的點 temp2 = Travel(further_node, target, max_dist) nodes_visited += temp2.nodes_visited if temp2.nearest_dist < dist: #如果另一個子節點中存在更近的距離 nearest = temp2.nearest_point #更新最近的點 dist = temp2.nearest_dist #更新最近距離 return result(nearest, dist, nodes_visited) return Travel(tree.root, point, float("inf")) #從根節點開始遞迴
測試結果,計算[2,4.5]離資料集:[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]中最近的點。
if __name__ == "__main__":
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KD_Tree(data)
rst = Find_Nearest(kd, [2,4.5])
[2,4.5]最近鄰為[2,3],最短距離為1.5.測試結果看出KD樹的效果還是不錯的。那麼在大資料高維度情況下,KD樹的測試結果怎樣呢。
(二)對比蠻力實現和KD樹實現的區別
對之前處理的1萬條樣本資料選擇8000條作為訓練集,2000條作為檢驗集。
考慮到程式碼行較多的情況,本次對比使用封裝模組,然後呼叫模組執行測試結果。
生成3個.py檔案:Sample.py、Knn.py和KD_Tree.py
此部分程式碼與前面的程式碼區別不大,就不再進行復制。如有需要可以在網頁連結中下載,提取碼: po7s。
執行檔案為Main,py
import sys
sys.path.append(r"D:/Python_work/機器學習/KNN分類演算法/Knn")
from Sample import Sample_PC
from datetime import datetime
#呼叫引數
k = 3
train_file_route = r"E:/data/digit_data_copy/train/"
test_file_route = r"E:/data/digit_data_copy/test/"
model = "KD_Tree"
#執行蠻力實現
func1 = Sample_PC(3,train_file_route, test_file_route,None)
t1 = datetime.now()
result1 = func1.test_data()
t2 = datetime.now()
print('knn耗時:', t2-t1)
#執行KD樹實現
func2 = Sample_PC(3,train_file_route, test_file_route, model=model)
t3 = datetime.now()
result2 = func2.test_data()
t4 = datetime.now()
print('KD_Tree耗時:', t4-t3)
結論:
蠻力實現:準確率:0.977,耗時:2分56秒
混淆矩陣
file_name | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|
forecast_data | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
real_data | ||||||||||
0 | 209 | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |
1 | 0 | 221 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 2 | 0 | 163 | 0 | 1 | 0 | 0 | 2 | 0 | 0 |
3 | 0 | 0 | 0 | 206 | 0 | 1 | 0 | 2 | 1 | 0 |
4 | 0 | 1 | 1 | 0 | 209 | 1 | 1 | 0 | 0 | 2 |
5 | 0 | 1 | 0 | 1 | 0 | 172 | 3 | 0 | 0 | 1 |
6 | 0 | 1 | 0 | 0 | 0 | 0 | 184 | 0 | 0 | 0 |
7 | 0 | 4 | 0 | 0 | 0 | 0 | 0 | 203 | 0 | 0 |
8 | 1 | 1 | 1 | 1 | 0 | 2 | 0 | 1 | 198 | 0 |
9 | 1 | 2 | 0 | 2 | 1 | 0 | 0 | 4 | 0 | 189 |
KD樹實現:準確率:0.989,耗時:1個小時53分鐘
混淆矩陣:
雖然,KD樹的準確率在蠻力實現之上,但KD樹對於高維大資料的計算大過於耗費時間,且準確率提升也不是特別高。總體而言,knn分類效果較好,但計算比較耗時,這也是它最大的一個缺點。