1. 程式人生 > >KD-tree的原理以及構建與查詢操作的python實現

KD-tree的原理以及構建與查詢操作的python實現

前幾天小組討論會上展示了kd-tree(k-dimension tree),感覺這玩意兒還挺有用的,所以學習了一下它的原理,然後把其中的構建kd-tree以及對應的查詢操作實現了一下,現在跟大家分享一下

首先說一下什麼是kd-tree把

不過首先得說一下bst(二叉查詢樹),遞迴定義如下:如果左子樹上的節點儲存的數值都小於根節點中儲存的數值,並且右子樹上的節點儲存的數值都大於根節點中儲存的數值,那麼這樣的二叉樹就是一顆二叉查詢樹

有了bst的概念,那麼kd-tree就 容易理解多了,首先kd-tree的節點中儲存的數值是一個k維的資料點,而bst的節點中儲存的可以視為是1維的資料點,kd-tree與bst不同的地方在於進行分支決策的時候,還需要選擇一個維度的值進行比較,選擇哪個維度呢?每個節點還需要維護一個split變數,表示進行分支決策的時候,選擇哪個維度的值進行比較,現在給出一個kd-tree節點的定義

  1. class KD_node:  
  2.     def __init__(self, point=None, split=None, LL = None, RR = None):  
  3.         """ 
  4.         point:資料點 
  5.         split:劃分域 
  6.         LL, RR:節點的左兒子跟右兒子 
  7.         """
  8.         self.point = point  
  9.         self.split = split  
  10.         self.left = LL  
  11.         self.right = RR  

point就代表節點儲存的k維資料點,left,right分別代表指向左右兒子的指標,split代表劃分維度,在節點進行劃分之前,我們需要確定劃分維度,那麼怎麼選擇劃分維度呢,這又要從kd-tree的用途開始說起了

kd-tree

是一種對高維空間的資料點進行劃分的特殊資料結構主要應用就是高維空間的資料查詢,如:範圍搜尋和K近鄰(knn)搜尋,範圍搜尋就是給定查詢點和距離閾值,獲取在閾值範圍內的所有資料點;knn搜尋就是給定查詢點和搜尋點的數目n,查找出到搜尋點最近的n個點的數目;

以上這兩種搜尋如果通過傳統方法來實現,那麼最壞情況下可能會窮舉  資料急中的所有點,這種方法的缺點就是完全沒有利用到資料集中蘊藏的結構資訊,當資料點很多時,搜尋效率不高;

事實上,實際資料集中的點一般時呈簇狀分佈的,所以,很多點我們是完全沒有必要遍歷的,索引樹的方法就是對將要搜尋的點進行空間劃分,空間劃分可能會有重疊,也可能沒有重疊,kd-tree就是劃分空間沒有重疊的索引樹

這樣說可能有一點亂,那我還是以“二分查詢”作為引入吧

如果給你一組資料   9 1 4 7 2 5 0 3 8

讓你查詢8,如果你挨個查詢,那麼將會把資料集都遍歷一遍,

如果你排一下序那現在資料集就變成了:0 1 2 3 4 5 6 7 8 9,其實我們進行了很多沒有必要的查詢,

如果我以5為分界點,那麼資料點就被分為了 兩個“簇” (0 1 2 3 4)和(6 7 8 9),如果我要查詢8,我根本久沒有必要進入第一個簇,直接進入第二個簇進行查詢,經過2次操作之後,就可以找到8了

把二分查詢中的資料點換成k維資料點,這樣的劃分就變成了我們剛才說的空間劃分,所以在這裡要搞清楚,空間劃分就是把資料點分類,“捱得近”的資料點就在一個空間裡面

好 現在回到剛才的劃分維度的選擇上,因為我要儘可能將相似的點放在一顆子樹裡面,所以kd-tree採取的思想就是計算所有資料點在每個維度上的數值的方差

然後方差最大的維度就作為當前節點的劃分維度,這樣做的原理其實就是:方差越大,說明這個維度上的資料波動越大,也就說明了他們就越不可能屬於同一個空間,需要在這個維度上對點進行劃分,這就是kd-tree節點選擇劃分維度的原理

先貼一張kd-tree的圖


途中每個節點代表劃分點,標示維黑體的維度就是節點的劃分維度,可以看到對於任意節點來說,如果給定劃分維度split, 它的左子樹上的節點在split維度上的值一定比它在split維度上的值要小,右子樹上的節點在split維度上的值一定相應要大,所以說kd-tree實際上就是bst在多維空間上的拓展

好,扯了那麼多廢話,舉個例子來說一下kd-tree的構造

現在假設我有若干個二維空間的資料點(橫向為x軸,縱向為y軸)


通過第一次方差的計算,我們發現x維度上的方差比較大,所以,我們先選x軸為劃分維度,得到下面的點,黃色的點代表分割點,這裡要說明一下,分割點(也就是節點儲存的資料節點)一般取在分割維度上的值為中間值的點,下圖就是選了x維度上的值為中值的點作為切割點


現在我們又對x<x0 和x>=x0空間進行劃分,先看x>=x0這個子空間,很明顯,y軸方向上的資料波動要比x軸方向上的資料波動更大,所以這個空間中我們選的劃分維度為y維度,紅色節點為分割點


我們按照上面的方法,持續對空間中的點進行劃分,直到每個空間中只有一個點,這樣,一棵kd-tree就構成了


根據上面的介紹,黃色的節點就代表kd-tree的根節點,也就是第一個分割點;紅色的點代表位於第二層上的節點,剩下的以此類推

好了,現在附上建立kd-tree的Python程式碼

  1. def createKDTree(root, data_list):  
  2.     """ 
  3.     root:當前樹的根節點 
  4.     data_list:資料點的集合(無序) 
  5.     return:構造的KDTree的樹根 
  6.     """
  7.     LEN = len(data_list)  
  8.     if LEN == 0:  
  9.         return
  10.     #資料點的維度
  11.     dimension = len(data_list[0])  
  12.     #方差
  13.     max_var = 0
  14.     #最後選擇的劃分域
  15.     split = 0;  
  16.     for i in range(dimension):  
  17.         ll = []  
  18.         for t in data_list:  
  19.             ll.append(t[i])  
  20.         var = computeVariance(ll)  
  21.         if var > max_var:  
  22.             max_var = var  
  23.             split = i  
  24.     #根據劃分域的資料對資料點進行排序
  25.     data_list.sort(key=lambda x: x[split])  
  26.     #選擇下標為len / 2的點作為分割點
  27.     point = data_list[LEN / 2]  
  28.     root = KD_node(point, split)  
  29.     root.left = createKDTree(root.left, data_list[0:(LEN / 2)])  
  30.     root.right = createKDTree(root.right, data_list[(LEN / 2 + 1):LEN])  
  31.     return root  
  32. def computeVariance(arrayList):  
  33.     """ 
  34.     arrayList:存放的資料點 
  35.     return:返回資料點的方差 
  36.     """
  37.     for ele in arrayList:  
  38.         ele = float(ele)  
  39.     LEN = len(arrayList)  
  40.     array = numpy.array(arrayList)  
  41.     sum1 = array.sum()  
  42.     array2 = array * array  
  43.     sum2 = array2.sum()  
  44.     mean = sum1 / LEN  
  45.     #D[X] = E[x^2] - (E[x])^2
  46.     variance = sum2 / LEN - mean**2
  47.     return variance  

說完了kd-tree的構建,現在再來說一下如何利用kd-tree進行最近鄰的查詢

基本的查詢思路是這樣的:

1.二叉查詢:從根節點開始進行查詢,直到葉子節點;在這個過程中,記錄最短的距離,和對應的資料點;同時維護一個棧,用來儲存經過的節點

2.回溯查詢:通過計算查詢點到分割平面的距離(這個距離比較的是分割維度上的值的差,並不是分割節點到分割平面上的距離,雖然兩者的值是相等的)與當前最短距離進行比較,決定是否需要進入節點的相鄰子空間進行查詢,為什麼需要這個判斷呢,我舉一個例子就大家可能就能明白了


途中的黑點為kd-tree中的資料點,五角星為查詢點,我們通過kd-tree的分支決策會將它分到坐上角的那部分空間,但並不是意味著它到那個空間中的點的距離最近

我們首先掃描到葉子節點,掃描的過程中記錄的最近點為p(5,4),最短距離為d, 現在開始回溯,假設分割的維度為ss,其實回溯的過程就是確定是否有必要進入相鄰子空間進行搜尋,確定的依據就是當前點到最近點的距離d是否大於當前點到分割面(在二維空間中實際上就是一條線)的距離L,如果d < L,那麼說明完全沒有必要進入到另一個子空間進行搜尋,直接繼續向上一層回溯;如果有d > L,那麼說明相鄰子空間中可能有距查詢點更近的點

  1. def findNN(root, query):  
  2.     """ 
  3.     root:KDTree的樹根 
  4.     query:查詢點 
  5.     return:返回距離data最近的點NN,同時返回最短距離min_dist 
  6.     """
  7.     #初始化為root的節點
  8.     NN = root.point  
  9.     min_dist = computeDist(query, NN)  
  10.     nodeList = []  
  11.     temp_root = root  
  12.     ##二分查詢建立路徑
  13.     while temp_root:  
  14.         nodeList.append(temp_root)  
  15.         dd = computeDist(query, temp_root.point)  
  16.         if min_dist > dd:  
  17.             NN = temp_root.point  
  18.             min_dist = dd  
  19.         #當前節點的劃分域
  20.         ss = temp_root.split  
  21.         if query[ss] <= temp_root.point[ss]:  
  22.             temp_root = temp_root.left  
  23.         else:  
  24.             temp_root = temp_root.right  
  25.     ##回溯查詢
  26.     while nodeList:  
  27.         #使用list模擬棧,後進先出
  28.         back_point = nodeList.pop()  
  29.         ss = back_point.split  
  30.         print"back.point = ", back_point.point  
  31.         ##判斷是否需要進入父親節點的子空間進行搜尋
  32.         if abs(query[ss] - back_point.point[ss]) < min_dist:  
  33.             if query[ss] <= back_point.point[ss]:  
  34.                 temp_root = back_point.right  
  35.             else:  
  36.                 temp_root = back_point.left  
  37.             if temp_root:  
  38.                 nodeList.append(temp_root)  
  39.                 curDist = computeDist(query, temp_root.point)  
  40.                 if min_dist > curDist:  
  41.                     min_dist = curDist  
  42.                     NN = temp_root.point  
  43.     return NN, min_dist  
  44. def computeDist(pt1, pt2):  
  45.     """ 
  46.     計算兩個資料點的距離 
  47.     return:pt1和pt2之間的距離 
  48.     """
  49.     sum = 0.0
  50.     for i in range(len(pt1)):  
  51.         sum = sum + (pt1[i] - pt2[i]) * (pt1[i] - pt2[i])  
  52.     return math.sqrt(sum)  

為了驗證建立的樹是否正確以及最後的距離度量是否正確,我分別使用了樹的前序遍歷和knn來對比執行的結果
  1. def preorder(root):  
  2.     """ 
  3.     KDTree的前序遍歷 
  4.     """
  5.     print root.point  
  6.     if root.left:  
  7.         preorder(root.left)  
  8.     if root.right:  
  9.         preorder(root.right)  
  10. def KNN(list, query):  
  11.     min_dist = 9999.0
  12.     NN = list[0]  
  13.     for pt in list:  
  14.         dist = computeDist(query, pt)  
  15.         if dist < min_dist:  
  16.             NN = pt  
  17.             min_dist = dist  
  18.     return NN, min_dist  

程式碼寫得比較急,而且測試的資料也比較簡單,所以難免會有bug,歡迎大家積極給我指出來   

參考文獻:http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf