1. 程式人生 > >統計學習三:2.K近鄰法代碼實現(以最近鄰法為例)

統計學習三:2.K近鄰法代碼實現(以最近鄰法為例)

數據集 learning pytho port 4.3 @property 存儲 uil github

通過上文可知感知機模型的基本原理,以及算法的具體流程。本文實現了感知機模型算法的原始形式,通過對算法的具體實現,我們可以對算法有進一步的了解。具體代碼可以在我的github上查看。

代碼

#!/usr/bin/python3
# -*- coding:utf-8 -*-

import sys 
import numpy as np

class Kdtree(object):
    ‘‘‘ 
    類名: Kdtree
    用於存儲kd樹的數據
    成員:
    __value: 訓練數據,保存數據點的坐標
     __type: 保存點對應的類型
      __dim: 保存當前kd樹節點的切分平面維度
       left: 左子樹
      right: 右子樹
    ‘‘‘
    def __init__(self, node = None, node_type = -1, dim = 0, left = None, right = None):
        self.__value = node
        self.__type  = node_type
        self.__dim   = dim 
        self.left    = left
        self.right   = right

    @property
    def type(self):
        return self.__type
        
    @property
    def value(self):
        return self.__value

    @property
    def dim(self):
        return self.__dim

    def distance(self, node):
        ‘‘‘ 
        計算當前節點與傳入節點之間的距離
        參數: 
        node: 需要計算距離的節點
        ‘‘‘
        if node == None:
            return sys.maxsize

        dis = 0 
        for i in range(len(self.__value)):
            dis = dis + (self.__value[i] - node.__value[i]) ** 2
        return dis
        
    def build_tree(self, nodes, dim = 0):
        ‘‘‘
        利用訓練數據建立一棵kd樹
        參數: nodes: 訓練數據集
                dim: 樹的切分平面維度
        return: a kd-tree
        ‘‘‘
        if len(nodes) == 0:
            return None
        elif len(nodes) == 1:
            self.__dim  = dim
            self.__value = nodes[0][:-1]
            self.__type  = nodes[0][-1]
            return self

        #將數據集按照第dim維度的值的大小進行排序
        sortNodes = sorted(nodes, key = lambda x:x[dim], reverse = False)

        #排序後,中間的點為當前節點值
        midNode      = sortNodes[len(sortNodes) // 2]
        self.__value = midNode[:-1]
        self.__type  = midNode[-1]
        self.__dim   = dim

        leftNodes  = list(filter(lambda x: x[dim] < midNode[dim], sortNodes[:len(sortNodes) // 2]))
        rightNodes = list(filter(lambda x: x[dim] >= midNode[dim], sortNodes[len(sortNodes) // 2 + 1:]))
        nextDim    = (dim + 1) % (len(midNode) - 1)

        self.left  = Kdtree().build_tree(leftNodes, nextDim)
        self.right = Kdtree().build_tree(rightNodes, nextDim)

        return self

    def find_type(self, fnode):
        ‘‘‘
        在kd樹內查找傳入點的最近鄰點和對應的類型
        參數: fnode: 需要判斷類型的點
        return: fnode的最近鄰點和其類型
        ‘‘‘
        if fnode == None:
            return self, -1

        fNode = Kdtree(fnode)

        #首先搜索整棵樹到達葉子節點
        path = []
        currentNode = self
        while currentNode != None:
            path.append(currentNode)

            dim   = currentNode.__dim
            if fNode.value[dim] < currentNode.value[dim]:
                currentNode = currentNode.left
            else:
                currentNode = currentNode.right

        #path的最後一個節點即為葉子節點
        nearestNode = path[-1]
        nearestDist = fNode.distance(nearestNode)
        path = path[:-1]

        #向上進行回溯
        while path != None and len(path) > 0:
            currentNode = path[-1]
            path = path[:-1]
            dim  = currentNode.__dim
            
            #判斷當前點是否比最近點更近
            if fNode.distance(currentNode) < nearestDist:
                nearestNode = currentNode
                nearestDist = fNode.distance(currentNode)

            #當前最近點一定存在於當前點的一棵子樹上,那麽找到它的兄弟子樹的節點
            brotherNode = currentNode.left
            if fNode.value[dim] < currentNode.value[dim]:
                brotherNode = currentNode.right

            if brotherNode == None:
                continue

            #若兄弟子樹的節點對應的區域與以fnode為圓心,以nearestDist為半徑的圓相交,則進入兄弟子樹,進行遞歸查找
            bdim = brotherNode.__dim
            if np.abs(fnode[bdim] - brotherNode.__value[bdim]) < nearestDist:
                cNode, _ = brotherNode.find_type(fnode)
                if fNode.distance(cNode) < nearestDist:
                    nearestDist = fNode.distance(cNode)
                    nearestNode = cNode

        return nearestNode, nearestNode.type

if __name__ == "__main__":

   #訓練數據集
   trainArray = [[1.0, 1.0, ‘a‘], [1.1, 1.1, ‘a‘], [1.5, 1.5, ‘a‘],            [5.0, 5.0, ‘b‘], [5.2, 5.2, ‘b‘], [5.5, 5.5, ‘b‘],            [3.0, 2.5, ‘c‘], [3.1, 2.8, ‘c‘], [3.2, 2.4, ‘c‘]]

   kdtree = Kdtree().build_tree(trainArray)

   #test1
   testNode = [1.6, 1.5]
   _, testType = kdtree.find_type(testNode)
   print("the type of ", testNode, "is ", testType)

   #test2
   testNode = [3.5, 2.7]
   _, testType = kdtree.find_type(testNode)
   print("the type of ", testNode, "is ", testType)

   #test3
   testNode = [4.3, 5.1]
   _, testType = kdtree.find_type(testNode)
   print("the type of ", testNode, "is ", testType)

測試結果

技術分享圖片

通過測試結果可知,kd樹可以有效地對輸入數據進行類型的識別。

討論

雖然通過測試結果正確,但代碼依然存在許多需要改進的地方,如kd樹的選擇,可以通過改進為紅黑平衡樹,來提高搜索速度。以及對於樹的每層切分平面的維度選擇,可以選擇各維度中方差最大的維度,這樣在此維度下的點分布更加分散,使後續的查找難度更小等等。

統計學習三:2.K近鄰法代碼實現(以最近鄰法為例)