Statistical Learning III: 2. Implementing the k-Nearest Neighbor Method in Code (Using the Nearest Neighbor Method as an Example)
阿新 • Published: 2018-09-12
The previous article covered the basic principles of the k-nearest neighbor model and the concrete steps of the algorithm. This article implements nearest-neighbor search with a kd-tree; working through the implementation gives a deeper understanding of the algorithm. The full code is available on my GitHub.
Code
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import sys

class Kdtree(object):
    '''
    Class: Kdtree
        Stores the data of one kd-tree node.
    Members:
        __value: training data, the coordinates of the point
        __type: the class label of the point
        __dim: the splitting dimension of this kd-tree node
        left: left subtree
        right: right subtree
    '''
    def __init__(self, node=None, node_type=-1, dim=0, left=None, right=None):
        self.__value = node
        self.__type = node_type
        self.__dim = dim
        self.left = left
        self.right = right

    @property
    def type(self):
        return self.__type

    @property
    def value(self):
        return self.__value

    @property
    def dim(self):
        return self.__dim

    def distance(self, node):
        '''
        Compute the squared Euclidean distance between this node and the given node.
        Args:
            node: the node to compute the distance to
        '''
        if node is None:
            return sys.maxsize
        dis = 0
        for i in range(len(self.__value)):
            dis = dis + (self.__value[i] - node.__value[i]) ** 2
        return dis

    def build_tree(self, nodes, dim=0):
        '''
        Build a kd-tree from the training data.
        Args:
            nodes: training data set (each row is coordinates followed by a label)
            dim: the splitting dimension at this level
        Returns: a kd-tree
        '''
        if len(nodes) == 0:
            return None
        elif len(nodes) == 1:
            self.__dim = dim
            self.__value = nodes[0][:-1]
            self.__type = nodes[0][-1]
            return self
        # sort the data set by the value of the dim-th coordinate
        sortNodes = sorted(nodes, key=lambda x: x[dim], reverse=False)
        # after sorting, the median point becomes the current node
        midNode = sortNodes[len(sortNodes) // 2]
        self.__value = midNode[:-1]
        self.__type = midNode[-1]
        self.__dim = dim
        leftNodes = list(filter(lambda x: x[dim] < midNode[dim],
                                sortNodes[:len(sortNodes) // 2]))
        rightNodes = list(filter(lambda x: x[dim] >= midNode[dim],
                                 sortNodes[len(sortNodes) // 2 + 1:]))
        nextDim = (dim + 1) % (len(midNode) - 1)
        self.left = Kdtree().build_tree(leftNodes, nextDim)
        self.right = Kdtree().build_tree(rightNodes, nextDim)
        return self

    def find_type(self, fnode):
        '''
        Search the kd-tree for the nearest neighbor of the given point and its type.
        Args:
            fnode: the query point whose type is to be determined
        Returns: the nearest neighbor of fnode and its type
        '''
        if fnode is None:
            return self, -1
        fNode = Kdtree(fnode)
        # first descend the tree to a leaf node
        path = []
        currentNode = self
        while currentNode is not None:
            path.append(currentNode)
            dim = currentNode.__dim
            if fNode.value[dim] < currentNode.value[dim]:
                currentNode = currentNode.left
            else:
                currentNode = currentNode.right
        # the last node on the path is the leaf node
        nearestNode = path[-1]
        nearestDist = fNode.distance(nearestNode)
        path = path[:-1]
        # backtrack toward the root
        while len(path) > 0:
            currentNode = path[-1]
            path = path[:-1]
            dim = currentNode.__dim
            # check whether the current node is closer than the nearest so far
            if fNode.distance(currentNode) < nearestDist:
                nearestNode = currentNode
                nearestDist = fNode.distance(currentNode)
            # the search descended into one subtree of the current node,
            # so locate the root of the sibling subtree
            brotherNode = currentNode.left
            if fNode.value[dim] < currentNode.value[dim]:
                brotherNode = currentNode.right
            if brotherNode is None:
                continue
            # if the sibling subtree's region intersects the hypersphere centered at
            # fnode with radius nearestDist, search the sibling subtree recursively.
            # The test compares against the current node's splitting hyperplane, and
            # squares the gap because distance() returns squared distances.
            if (fnode[dim] - currentNode.value[dim]) ** 2 < nearestDist:
                cNode, _ = brotherNode.find_type(fnode)
                if fNode.distance(cNode) < nearestDist:
                    nearestDist = fNode.distance(cNode)
                    nearestNode = cNode
        return nearestNode, nearestNode.type

if __name__ == "__main__":
    # training data set
    trainArray = [[1.0, 1.0, 'a'], [1.1, 1.1, 'a'], [1.5, 1.5, 'a'],
                  [5.0, 5.0, 'b'], [5.2, 5.2, 'b'], [5.5, 5.5, 'b'],
                  [3.0, 2.5, 'c'], [3.1, 2.8, 'c'], [3.2, 2.4, 'c']]
    kdtree = Kdtree().build_tree(trainArray)
    # test1
    testNode = [1.6, 1.5]
    _, testType = kdtree.find_type(testNode)
    print("the type of ", testNode, "is ", testType)
    # test2
    testNode = [3.5, 2.7]
    _, testType = kdtree.find_type(testNode)
    print("the type of ", testNode, "is ", testType)
    # test3
    testNode = [4.3, 5.1]
    _, testType = kdtree.find_type(testNode)
    print("the type of ", testNode, "is ", testType)
Test Results
The test results show that the kd-tree correctly identifies the class of each input point.
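Because a correct kd-tree search must return exactly the same neighbor as an exhaustive scan, the three test cases can be cross-checked with a brute-force lookup. This is a small verification sketch, not part of the original code; `trainArray` and the test points mirror the ones in the script above.

```python
# Brute-force nearest-neighbor lookup, used only as an independent
# sanity check of the kd-tree answers above.
trainArray = [[1.0, 1.0, 'a'], [1.1, 1.1, 'a'], [1.5, 1.5, 'a'],
              [5.0, 5.0, 'b'], [5.2, 5.2, 'b'], [5.5, 5.5, 'b'],
              [3.0, 2.5, 'c'], [3.1, 2.8, 'c'], [3.2, 2.4, 'c']]

def brute_force_type(train, point):
    # squared Euclidean distance, ignoring the trailing label column
    def sqdist(row):
        return sum((a - b) ** 2 for a, b in zip(row[:-1], point))
    # the label of the row with the smallest distance to the query point
    return min(train, key=sqdist)[-1]

for p in [[1.6, 1.5], [3.5, 2.7], [4.3, 5.1]]:
    print("the type of ", p, "is ", brute_force_type(trainArray, p))
```

For these three queries the brute-force scan yields 'a', 'c', and 'b', which is what the kd-tree search should agree with.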
Discussion
Although the test results are correct, the code still leaves room for improvement. For example, the tree structure could be kept balanced (in the spirit of a red-black tree's rebalancing) to speed up searches. Likewise, instead of cycling through dimensions, the splitting dimension at each level could be chosen as the dimension with the largest variance; points are most spread out along that dimension, which makes subsequent searches cheaper.
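The variance-based splitting idea above can be sketched as a small helper. `best_split_dim` is a hypothetical function name, not part of the original code; rows are assumed to end with a class label, as in `trainArray`.

```python
# Sketch: pick the splitting dimension with the largest variance,
# instead of cycling dimensions with (dim + 1) % n_dims.
# Rows are [coord_0, ..., coord_{n-1}, label], as in trainArray.
def best_split_dim(nodes):
    n_dims = len(nodes[0]) - 1
    def variance(d):
        vals = [row[d] for row in nodes]
        mean = sum(vals) / len(vals)
        return sum((v - mean) ** 2 for v in vals) / len(vals)
    # the dimension along which the points are most spread out
    return max(range(n_dims), key=variance)
```

In `build_tree`, `nextDim = (dim + 1) % (len(midNode) - 1)` would then be replaced by a call like `best_split_dim(nodes)` computed on the node's subset of points.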