KD-tree的原理以及構建與查詢操作的python實現
前幾天小組討論會上展示了kd-tree(k-dimension tree),感覺這玩意兒還挺有用的,所以學習了一下它的原理,然後把其中的構建kd-tree以及對應的查詢操作實現了一下,現在跟大家分享一下
首先說一下什麼是kd-tree把
不過首先得說一下bst(二叉查詢樹),遞迴定義如下:如果左子樹上的節點儲存的數值都小於根節點中儲存的數值,並且右子樹上的節點儲存的數值都大於根節點中儲存的數值,那麼這樣的二叉樹就是一顆二叉查詢樹
有了bst的概念,那麼kd-tree就 容易理解多了,首先kd-tree的節點中儲存的數值是一個k維的資料點,而bst的節點中儲存的可以視為是1維的資料點,kd-tree與bst不同的地方在於進行分支決策的時候,還需要選擇一個維度的值進行比較,選擇哪個維度呢?每個節點還需要維護一個split變數,表示進行分支決策的時候,選擇哪個維度的值進行比較,現在給出一個kd-tree節點的定義
- class KD_node:
- def __init__(self, point=None, split=None, LL = None, RR = None):
- """
- point:資料點
- split:劃分域
- LL, RR:節點的左兒子跟右兒子
- """
- self.point = point
- self.split = split
- self.left = LL
- self.right = RR
point就代表節點儲存的k維資料點,left,right分別代表指向左右兒子的指標,split代表劃分維度,在節點進行劃分之前,我們需要確定劃分維度,那麼怎麼選擇劃分維度呢,這又要從kd-tree的用途開始說起了
kd-tree
以上這兩種搜尋如果通過傳統方法來實現,那麼最壞情況下可能會窮舉 資料急中的所有點,這種方法的缺點就是完全沒有利用到資料集中蘊藏的結構資訊,當資料點很多時,搜尋效率不高;
事實上,實際資料集中的點一般時呈簇狀分佈的,所以,很多點我們是完全沒有必要遍歷的,索引樹的方法就是對將要搜尋的點進行空間劃分,空間劃分可能會有重疊,也可能沒有重疊,kd-tree就是劃分空間沒有重疊的索引樹
這樣說可能有一點亂,那我還是以“二分查詢”作為引入吧
如果給你一組資料 9 1 4 7 2 5 0 3 8
讓你查詢8,如果你挨個查詢,那麼將會把資料集都遍歷一遍,
如果你排一下序那現在資料集就變成了:0 1 2 3 4 5 6 7 8 9,其實我們進行了很多沒有必要的查詢,
如果我以5為分界點,那麼資料點就被分為了 兩個“簇” (0 1 2 3 4)和(6 7 8 9),如果我要查詢8,我根本久沒有必要進入第一個簇,直接進入第二個簇進行查詢,經過2次操作之後,就可以找到8了
把二分查詢中的資料點換成k維資料點,這樣的劃分就變成了我們剛才說的空間劃分,所以在這裡要搞清楚,空間劃分就是把資料點分類,“捱得近”的資料點就在一個空間裡面
好 現在回到剛才的劃分維度的選擇上,因為我要儘可能將相似的點放在一顆子樹裡面,所以kd-tree採取的思想就是計算所有資料點在每個維度上的數值的方差
然後方差最大的維度就作為當前節點的劃分維度,這樣做的原理其實就是:方差越大,說明這個維度上的資料波動越大,也就說明了他們就越不可能屬於同一個空間,需要在這個維度上對點進行劃分,這就是kd-tree節點選擇劃分維度的原理
先貼一張kd-tree的圖途中每個節點代表劃分點,標示維黑體的維度就是節點的劃分維度,可以看到對於任意節點來說,如果給定劃分維度split, 它的左子樹上的節點在split維度上的值一定比它在split維度上的值要小,右子樹上的節點在split維度上的值一定相應要大,所以說kd-tree實際上就是bst在多維空間上的拓展
好,扯了那麼多廢話,舉個例子來說一下kd-tree的構造
現在假設我有若干個二維空間的資料點(橫向為x軸,縱向為y軸)
通過第一次方差的計算,我們發現x維度上的方差比較大,所以,我們先選x軸為劃分維度,得到下面的點,黃色的點代表分割點,這裡要說明一下,分割點(也就是節點儲存的資料節點)一般取在分割維度上的值為中間值的點,下圖就是選了x維度上的值為中值的點作為切割點
現在我們又對x<x0 和x>=x0空間進行劃分,先看x>=x0這個子空間,很明顯,y軸方向上的資料波動要比x軸方向上的資料波動更大,所以這個空間中我們選的劃分維度為y維度,紅色節點為分割點
我們按照上面的方法,持續對空間中的點進行劃分,直到每個空間中只有一個點,這樣,一棵kd-tree就構成了
根據上面的介紹,黃色的節點就代表kd-tree的根節點,也就是第一個分割點;紅色的點代表位於第二層上的節點,剩下的以此類推
好了,現在附上建立kd-tree的Python程式碼
- def createKDTree(root, data_list):
- """
- root:當前樹的根節點
- data_list:資料點的集合(無序)
- return:構造的KDTree的樹根
- """
- LEN = len(data_list)
- if LEN == 0:
- return
- #資料點的維度
- dimension = len(data_list[0])
- #方差
- max_var = 0
- #最後選擇的劃分域
- split = 0;
- for i in range(dimension):
- ll = []
- for t in data_list:
- ll.append(t[i])
- var = computeVariance(ll)
- if var > max_var:
- max_var = var
- split = i
- #根據劃分域的資料對資料點進行排序
- data_list.sort(key=lambda x: x[split])
- #選擇下標為len / 2的點作為分割點
- point = data_list[LEN / 2]
- root = KD_node(point, split)
- root.left = createKDTree(root.left, data_list[0:(LEN / 2)])
- root.right = createKDTree(root.right, data_list[(LEN / 2 + 1):LEN])
- return root
- def computeVariance(arrayList):
- """
- arrayList:存放的資料點
- return:返回資料點的方差
- """
- for ele in arrayList:
- ele = float(ele)
- LEN = len(arrayList)
- array = numpy.array(arrayList)
- sum1 = array.sum()
- array2 = array * array
- sum2 = array2.sum()
- mean = sum1 / LEN
- #D[X] = E[x^2] - (E[x])^2
- variance = sum2 / LEN - mean**2
- return variance
說完了kd-tree的構建,現在再來說一下如何利用kd-tree進行最近鄰的查詢
基本的查詢思路是這樣的:
1.二叉查詢:從根節點開始進行查詢,直到葉子節點;在這個過程中,記錄最短的距離,和對應的資料點;同時維護一個棧,用來儲存經過的節點
2.回溯查詢:通過計算查詢點到分割平面的距離(這個距離比較的是分割維度上的值的差,並不是分割節點到分割平面上的距離,雖然兩者的值是相等的)與當前最短距離進行比較,決定是否需要進入節點的相鄰子空間進行查詢,為什麼需要這個判斷呢,我舉一個例子就大家可能就能明白了
途中的黑點為kd-tree中的資料點,五角星為查詢點,我們通過kd-tree的分支決策會將它分到坐上角的那部分空間,但並不是意味著它到那個空間中的點的距離最近
我們首先掃描到葉子節點,掃描的過程中記錄的最近點為p(5,4),最短距離為d, 現在開始回溯,假設分割的維度為ss,其實回溯的過程就是確定是否有必要進入相鄰子空間進行搜尋,確定的依據就是當前點到最近點的距離d是否大於當前點到分割面(在二維空間中實際上就是一條線)的距離L,如果d < L,那麼說明完全沒有必要進入到另一個子空間進行搜尋,直接繼續向上一層回溯;如果有d > L,那麼說明相鄰子空間中可能有距查詢點更近的點
- def findNN(root, query):
- """
- root:KDTree的樹根
- query:查詢點
- return:返回距離data最近的點NN,同時返回最短距離min_dist
- """
- #初始化為root的節點
- NN = root.point
- min_dist = computeDist(query, NN)
- nodeList = []
- temp_root = root
- ##二分查詢建立路徑
- while temp_root:
- nodeList.append(temp_root)
- dd = computeDist(query, temp_root.point)
- if min_dist > dd:
- NN = temp_root.point
- min_dist = dd
- #當前節點的劃分域
- ss = temp_root.split
- if query[ss] <= temp_root.point[ss]:
- temp_root = temp_root.left
- else:
- temp_root = temp_root.right
- ##回溯查詢
- while nodeList:
- #使用list模擬棧,後進先出
- back_point = nodeList.pop()
- ss = back_point.split
- print"back.point = ", back_point.point
- ##判斷是否需要進入父親節點的子空間進行搜尋
- if abs(query[ss] - back_point.point[ss]) < min_dist:
- if query[ss] <= back_point.point[ss]:
- temp_root = back_point.right
- else:
- temp_root = back_point.left
- if temp_root:
- nodeList.append(temp_root)
- curDist = computeDist(query, temp_root.point)
- if min_dist > curDist:
- min_dist = curDist
- NN = temp_root.point
- return NN, min_dist
- def computeDist(pt1, pt2):
- """
- 計算兩個資料點的距離
- return:pt1和pt2之間的距離
- """
- sum = 0.0
- for i in range(len(pt1)):
- sum = sum + (pt1[i] - pt2[i]) * (pt1[i] - pt2[i])
- return math.sqrt(sum)
為了驗證建立的樹是否正確以及最後的距離度量是否正確,我分別使用了樹的前序遍歷和knn來對比執行的結果
- def preorder(root):
- """
- KDTree的前序遍歷
- """
- print root.point
- if root.left:
- preorder(root.left)
- if root.right:
- preorder(root.right)
- def KNN(list, query):
- min_dist = 9999.0
- NN = list[0]
- for pt in list:
- dist = computeDist(query, pt)
- if dist < min_dist:
- NN = pt
- min_dist = dist
- return NN, min_dist
程式碼寫得比較急,而且測試的資料也比較簡單,所以難免會有bug,歡迎大家積極給我指出來
參考文獻:http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf