1. 程式人生 > >基於kd樹的KNN演算法的實現

基於kd樹的KNN演算法的實現

記得大三初期,剛從大連理工大學回來,眼巴巴的望著同學各自都有著落了,就我一副“初出茅廬,不諳世事”的樣子,於是不得不覥著臉厚著皮去找老師,懇求他讓我去海洋所實習。他給我的第一份差事便是將幾個G的圖片裡的數字輸入到excel,我整整輸了一個國慶節假日。當時就在到處詢問,有沒有那種演算法可以讓自動識別圖片裡的數字,存入到excel中去,想來,那時的自己也是夠拼的。
如今這個自動識別數字的演算法算是寫出來了吧, 我至少可以這樣自我安慰到。
KNN演算法的理論算的上是最簡單最直觀的一種了,比起前幾次的支援向量機、貝葉斯、邏輯斯特迴歸那是簡單太多了,都不用推導半個公式。這周的核心都是在完成k-近鄰中kd樹的構建和搜尋,幾乎都是自己完成的,也沒有經過周密的測試,只是除錯調通了。
我想,它的用例定不止於此,但這個用例說出去可謂是最唬人的了。
識別下面的“圖片”為數字2 0 8

原理就不多講了,感興趣的網上都有,理論很簡單,只是構建和搜尋kd樹可能會有些麻煩,而kd樹只是為了讓它執行的更快,其實用最簡單粗暴的方法計算目標點與每個訓練集點的距離也未嘗不可。

程式執行的效果還行,識別近千個數字只錯了10個,錯誤率1%左右。效果如下,為了好看,我就僅截圖出識別幾個數字的效果:

下面為實現的程式,註釋寫的很明白,訓練資料在網上也不難找到:

KnnHelper.py

import numpy as np
'''
Created on 2017年7月17日

@author: fujianfei
'''

class KDNode(object):
    '''
    定義KD節點:
    point:節點裡面的樣本點,指的就是一個樣本點
    split:分割緯度(即用哪個緯度的資料進行切分,比如4維資料,split=3,則表示按照第4列的資料進行切分空間)
    left:節點的左子節點
    right:節點的右子節點
    '''
def __init__(self, point=None, split=None, left=None, right=None): ''' Constructor ''' self.point = point self.split = split self.left = left self.right = right class KDTree(object): ''' 定義: KDNode:kd-tree的節點 dimensions:資料的緯度 right:節點的右子節點 left:節點的左子節點 curr_axis:當前需要切分的緯度 next_axis:下一次需要切分的緯度 '''
def __init__(self, data=None): ''' Constructor ''' def createNode(split=None, data_set=None): ''' 建立KD節點 輸入值:split:分割緯度 data_set:需要分割的樣本點集合 返回值:KDNode:KD節點 ''' if len(data_set) == 0: # 資料集為空,作為遞迴的停止條件 return None #找到split維的中位數median,先對資料進行排序,按照split維的資料大小排序 data_set = list(data_set) data_set.sort(key=lambda x: x[split])#對data_set進行排序,lambda是隱函式,具體用法請百度。排序方式為按照split維的資料大小排序 data_set = np.array(data_set) median = len(data_set) // 2#//為python的整數除法,找到中間點的位置median,按照這個位置進行空間切分 #返回KD節點 #輸入的變數分別是: #data_set[median],中間點位置的樣本點,傳入KDNode即節點裡麵包含的資料 #split,該節點的緯度分度位置 #createNode(maxVar(data_set[:median]),data_set[:median]),該節點的左節點,maxVar(data_set[:median])為左節點的緯度分度位置,data_set[:median]為左節點包含的空間裡的所有資料 #同理,createNode(maxVar(data_set[median+1:]),data_set[median+1:]),為右節點。 #用的是函式的遞迴建立樹,因為要不斷的呼叫函式,這個方法速度不快,用基本語句(判斷、迴圈)去構建樹的方法會更快 return KDNode(data_set[median], split, createNode(maxVar(data_set[:median]),data_set[:median]), createNode(maxVar(data_set[median+1:]),data_set[median+1:])) def maxVar(data_set=None): ''' 按緯度計算樣本集的最大方差緯度 輸入值:data_set:樣本集 輸出值:split:最大方差的緯度,作為createNode的輸入值 ''' if len(data_set) == 0: # 資料集為空,作為遞迴的停止條件 return 0 data_mean = np.mean(data_set,axis=0)#axis=0表示按列求均值 mean_differ = data_set - data_mean#均值差 data_var = np.sum(mean_differ ** 2,axis=0)/len(data_set)#按列求均值差平方之和,再除以樣本數,便是方差 re = np.where(data_var == np.max(data_var))#尋找方差最大的位置,也就是第幾緯方差最大,返回它 return re[0][0] self.root = createNode(maxVar(data),data)#定義根節點,分割緯度是使得樣本點方差最大的緯度,需要分割的樣本點為全資料 def computeDist(pt1, pt2): """ 計算兩個資料點的距離 return:pt1和pt2之間的距離 """ sum = 0.0 for i in range(len(pt1)): sum = sum + (pt1[i] - pt2[i]) * (pt1[i] - pt2[i]) return np.math.sqrt(sum) def preOrder(root): ''' KD樹的前序遍歷 ''' print(root.point) if root.left: preOrder(root.left) if root.right: preOrder(root.right) def updateNN(min_dist_array=None, tmp_dist=0.0, NN=None, tmp_point=None, k=1): ''' /更新近鄰點和對應的最小距離集合 min_dist_array為最小距離的集合 NN為近鄰點的集合 tmp_dist和tmp_point分別是需要更新到min_dist_array,NN裡的近鄰點和距離 ''' if tmp_dist <= np.min(min_dist_array) : for i in range(k-1,0,-1) : min_dist_array[i] = min_dist_array[i-1] NN[i] = NN[i-1] min_dist_array[0] = tmp_dist NN[0] = tmp_point return NN,min_dist_array for i in range(k) : if (min_dist_array[i] <= tmp_dist) and (min_dist_array[i+1] >= tmp_dist) : #tmp_dist在min_dist_array的第i位和第i+1位之間,則插入到i和i+1之間,並把最後一位給剔除掉 for j in range(k-1,i,-1) : #range反向取值 min_dist_array[j] = min_dist_array[j-1] NN[j] = NN[j-1] min_dist_array[i+1] = tmp_dist NN[i+1] = tmp_point break return NN,min_dist_array def searchKDTree(KDTree=None, target_point=None, k=1): ''' /搜尋kd樹 /輸入值:KDTree,kd樹;target_point,目標點;k,距離目標點最近的k個點的k值 /輸出值:k_arrayList,距離目標點最近的k個點的集合陣列 ''' if k == 0 : return None #從根節點出發,遞迴地向下訪問kd樹。若目標點當前維的座標小於切分點的座標,則移動到左子節點,否則移動到右子節點 tempNode = KDTree.root#定義臨時節點,先從根節點出發 NN = [tempNode.point] * k#定義最鄰近點集合,k個元素,按照距離遠近,由近到遠。初始化為k個根節點 min_dist_array = [float("inf")] * k#定義近鄰點與目標點距離的集合.初始化為無窮大 # for i in range(k) : # NN[i] = tempNode.point#定義最鄰近點集合,k個元素,按照距離遠近,由近到遠。初始化為k個根節點以下往左的集合 # min_dist_array[i] = computeDist(NN[i],target_point)#定義近鄰點與目標點距離的集合 # tempNode = tempNode.left nodeList = []#我們是用二分查詢建立路徑,定義依次查詢節點的list def buildSearchPath(tempNode=None, nodeList=None, min_dist_array=None, NN=None, target_point=None): ''' P:此方法是用來建立以tempNode為根節點,以下所有節點的查詢路徑,並將它們存放到nodeList中 nodeList為一系列節點的順序組合,按此先後順序搜尋最鄰近點 tempNode為"根節點",即以它為根節點,查詢它以下所有的節點(空間) ''' while tempNode : nodeList.append(tempNode) split = tempNode.split#節點的分割緯度 point = tempNode.point#節點包含的資料,當前例項點 tmp_dist = computeDist(point,target_point) if tmp_dist < np.max(min_dist_array) : #小於min_dist_array中最大的距離 NN,min_dist_array = updateNN(min_dist_array, tmp_dist, NN, point, k)#更新最小距離和最鄰近點 if target_point[split] <= point[split] : #如果目標點當前維的值小於等於切分點的當前維座標值,移動到左節點 tempNode = tempNode.left else : #如果目標點當前維的值大於切分點的當前維座標值,移動到右節點 tempNode = tempNode.right return NN,min_dist_array #建立查詢路徑 NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point) #回溯查詢 while nodeList : back_node = nodeList.pop()#將nodeList裡的元素從後往前一個個推出來 split = back_node.split point = back_node.point #判斷是否需要進入父節點搜素 #如果當前緯度,目標點減例項點大於最小距離,就沒必要進入父節點搜素了 #因為目標點到切割超平面的距離很大,那鄰近點肯定不在那個切割的空間裡,即沒必要進入那個空間搜素了 if not abs(target_point[split] - point[split]) >= np.max(min_dist_array) : #判斷是搜尋左子節點,還是搜尋右子節點 if (target_point[split] <= point[split]) : #如果目標點在左子節點的空間,則搜尋右子節點,檢視右節點是否有更鄰近點 tempNode = back_node.right else : #如果目標點在右子節點的空間,則搜尋左子節點,檢視左節點是否有更鄰近點 tempNode = back_node.left if tempNode : #把tempNode(此時它為另一個全新的未搜素的空間,需要將它放入nodeList,進行最近鄰搜尋)放入nodeList #nodeList.append(tempNode) #不能單純地將tempNode存放到nodeList,這樣下次只會搜尋這一個節點 #因為tempNode可做為一個全新的空間,故而需重新以它為根節點,構建查詢路徑,搜尋它名下所有的節點 NN,min_dist_array = buildSearchPath(tempNode,nodeList,min_dist_array, NN, target_point) # curr_dist = computeDist(tempNode.point,target_point) #是否該節點為更鄰近點,如果是,賦值給最鄰近點 # if curr_dist < np.max(min_dist_array) : # NN,min_dist_array = updateNN(min_dist_array, curr_dist, NN, tempNode.point, k)#更新最小距離和最鄰近點 return NN,min_dist_array def classify0(inX, dataSet, labels, k): ''' k近鄰演算法的分類器 \輸入: inX:目標點 dataSet:訓練點集合 labels:訓練點對應的標籤 k:k值 \這個方法的目的:已知訓練點dataSet和對應的標籤labels,確定目標點inX對應的labels ''' kd = KDTree(dataSet)#構建dataSet的kd樹 NN,min_dist_array = searchKDTree(kd, inX, k)#搜尋kd樹,返回最近的k個點的集合NN,和對應的距離min_dist_array dataSet = dataSet.tolist() voteIlabels = [] #多數投票法則確定inX的標籤,為防止邊界處分類不準的情況,以距離的倒數為權重,即距離越近,權重越大,越該認為inX是屬於該類 for i in range(k) : #找到每個近鄰點對應的標籤 nni = list(NN[i]) voteIlabels.append(labels[dataSet.index(nni)]) # #開始記數,加權重的方法 # uniques = np.unique(voteIlabels) # counts = [0.0] * len(uniques) # for i in range(len(voteIlabels)) : # for j in range(len(uniques)) : # if voteIlabels[i] == uniques[j] : # counts[j] = counts[j] + uniques[j] / min_dist_array[i] #權重為距離的倒數 # break #開始記數,不加權重的方法 uniques, counts = np.unique(voteIlabels, return_counts=True) return uniques[np.argmax(counts)]

**

HandWriting.py

**

import numpy as np
from os import listdir
from KNN import KnnHelper
'''
Created on 2017年7月23日

@author: fujianfei
'''
def img2vector(filename):
    '''
    \將32x32影象轉化為1x1024的向量
    '''
    returnVect = np.zeros((1,1024))
    fr = open(filename)
    for i in range(32) :
        lineStr = fr.readline()
        for j in range(32) :
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    '''
    \識別手寫數字
    '''
    hwLabels = []#定義訓練資料對應的標籤集合,即數字0-9
    trainingFileList = listdir('trainingDigits')#獲取trainingDigits目錄下所有的檔名,存在trainingFileList中
    m = len(trainingFileList)
    trainingMat = np.zeros((m,1024))
    for i in range(m) :
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#將檔案轉化為矩陣
    testFileList = listdir('testDigits')#獲取testDigits目錄下所有的檔名,存在testFileList中
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest) :
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)#將檔案轉化為矩陣
        vectorUnderTest = list(vectorUnderTest[0])
        classifierResult = KnnHelper.classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("演算法識別的數字為  : %d , 真實的數字為 : %d " % (classifierResult, classNumStr))
        if (classifierResult != classNumStr) : errorCount += 1.0
    print("\n總共出錯的次數為  : %d" % errorCount)
    print("\n出錯率為  : %f" % (errorCount/float(mTest)))

_init_.py

import numpy as np
from KNN import KnnHelper,HandWriting

# data = [[4,1,3,5],[3,6,5,7],[5,2,6.5,5],[4.8,4.2,5,8],[1,1,8,6],[1,6,5,3],[4.1,3.7,2,5],[4.7,4.1,5,9],[2,4,6,8.7]]  # samples
# kd = KnnHelper.KDTree(data)
# # KnnHelper.preOrder(kd.root)
# ret = KnnHelper.searchKDTree(kd, [4.8,3.8,2,4], 9)
# print (ret)
HandWriting.handwritingClassTest()