1. 程式人生 > >機器學習——詳解KD-Tree原理

機器學習——詳解KD-Tree原理

本文始發於個人公眾號:TechFlow,原創不易,求個關注


今天是機器學習的第15篇文章,之前的文章當中講了Kmeans的相關優化,還講了大名鼎鼎的EM演算法。有些小夥伴表示喜歡看這些硬核的,於是今天上點硬菜,我們來看一個機器學習領域經常用到的資料結構——KD-Tree。

從線段樹到KD樹

在講KD樹之前,我們先來了解一下線段樹的概念。線段樹在機器學習領域當中不太常見,作為高效能維護的資料結構,經常出現在各種演算法比賽當中。線段樹的本質是一棵維護一段區間的平衡二叉樹。

比如下圖就是一個經典的線段樹:

從下圖當中我們不難看出來,這棵線段樹維護的是一個區間內的最大值。比如樹根是8,維護的是整個區間的最大值,每一箇中間節點的值都是以它為樹根的子樹中所有元素的最大值。

通過線段樹,我們可以在 的時間內計算出某一個連續區間的最大值。比如我們來看下圖:

當我們要求被框起來的區間中的最大值,我們只需要找到能夠覆蓋這個區間的中間節點就行。我們可以發現被紅框框起來的兩個節點的子樹剛好覆蓋這個區間,於是整個區間的最大值,就是這兩個元素的最大值。這樣,我們就把一個需要 查詢的問題降低成了 ,不但如此,我們也可以做到 複雜度內的更新,也就是說我們不但可以快速查詢,還可以更新線段當中的元素。

當然線段樹的應用非常廣泛,也有許多種變體,這裡我們不過多深入,感興趣的同學可以期待一下週三的演算法與資料結構專題,在之後的文章當中會為大家分享線段樹的相關內容。在這裡,我們只需要有一個大概的印象,線段樹究竟完成的是什麼樣的事情即可。

線段樹維護的是一個線段,也就是區間內的元素,也就是說維護的是一個一維的序列。如果我們將資料的維度擴充一下,擴充到多維呢?

是的,你沒有猜錯,從某種程度上來說,我們可以把KD-Tree看成是線段樹拓展到多維空間當中的情況。

KD-Tree定義

我們來看一下KD-Tree的具體定義,這裡的K指的是K維空間,D自然就是dimension,也就是維度,也就是說KD-Tree就是K維度樹的意思。

在我們構建線段樹的時候,其實是一個遞迴的建樹過程,我們每次把當前的線段一分為二,然後用分成兩半的資料分別構建左右子樹。我們可以簡單寫一下虛擬碼,來更直觀地感受一下:

class Node:
    def __init__(self, value, lchild, rchild):
        self.value = value
        self.lchild = lchild
        self.rchild = rchild   
        
def build(arr):
    n = len(arr):
    left, right = arr[: n//2], arr[n//2:]
    lchild, rchild = self.build(left), self.build(right)
    return Node(max(lchild.value, rchild.value), lchild, rchild)

我們來看一個二維的例子,在一個二維的平面當中分佈著若干個點。

我們首先選擇一個維度將這些資料一分為二,比如我們選擇x軸。我們對所有資料按照x軸的值排序,選出其中的中點進行一分為二。

在這根線左右兩側的點被分成了兩棵子樹,對於這兩個部分的資料來說,我們更換一個維度,也就是選擇y軸進行劃分。一樣,我們先排序,然後找到中間的點,再次一分為二。我們可以得到:

我們重複上述過程,一直將點分到不能分為止,為了能更好地看清楚,我們對所有資料標上座標(並不精確)。

如果我們把空間看成是廣義的區間,那麼它和線段樹的原理是一樣的。最後得到的也是一棵完美二叉樹,因為我們每次都選擇了資料集的中點進行劃分,可以保證從樹根到葉子節點的長度不會超過

我們代入上面的座標之後,我們最終得到的KD-Tree大概是下面這個樣子:

KD-Tree 建樹

在建樹的過程當中,我們的樹深每往下延伸一層,我們就會換一個維度作為衡量標準。原因也很簡單,因為我們希望這棵樹對於這K維空間都有很好的表達能力,方便我們根據不同的維度快速查詢。

在一些實現當中,我們會計算每一個維度的方差,然後選擇方差較大的維度進行切分。這樣做自然是因為方差較大的維度說明資料相對分散,切分之後可以把資料區分得更加明顯。但我個人覺得這樣做意義不是很大,畢竟計算方差也是一筆開銷。所以這裡我們選擇了最樸素的方法——輪流選擇。

也就是說我們從樹根開始,選擇第0維作為排序和切分資料的依據,然後到了樹深為1的這一層,我們選擇第一維,樹深為2的這一層,我們選擇第二維,以此類推。當樹深超過了K的時候,我們就對樹深取模。

明確了這一點之後,我們就可以來寫KD-Tree的建樹程式碼了,和上面二叉樹的程式碼非常相似,只不過多了維度的處理而已。

    # 外部暴露介面
    def build_model(self, dataset):
        self.root = self._build_model(dataset)
        # 先忽略,容後再講
        self.set_father(self.root, None)

    # 內部實現的介面
    def _build_model(self, dataset, depth=0):
        if len(dataset) == 0:
            return None

        # 通過樹深對K取模來獲得當前對哪一維切分
        axis = depth % self.K
        m = len(dataset) // 2
        # 根據axis這一維排序
        dataset = sorted(dataset, key=lambda x: x[axis])
        # 將資料一分為二
        left, mid, right = dataset[:m], dataset[m], dataset[m+1:]

        # 遞迴建樹
        return KDTree.Node(
            mid[axis],
            mid,
            axis,
            depth,
            len(dataset),
            self._build_model(left, depth+1),
            self._build_model(right, depth+1)
        )

這樣我們就建好了樹,但是在後序的查詢當中我們需要訪問節點的父節點,所以我們需要為每一個節點都賦值指向父親節點的指標。這個值我們可以寫在建樹的程式碼裡,但是會稍稍複雜一些,所以我把它單獨拆分了出來,作為一個獨立的函式來給每一個節點賦值。對於根節點來說,由於它沒有父親節點,所以賦值為None。

我們來看下set_father當中的內容,其實很簡單,就是一個樹的遞迴遍歷:

def set_father(self, node, father):
    if node is None:
        return
    # 賦值
    node.father = father
    # 遞迴左右
    self.set_father(node.lchild, node)
    self.set_father(node.rchild, node)

快速批量查詢

KD-Tree建樹建好了肯定是要來用的,它最大的用處是可以在單次查詢中獲得距離樣本最近的若干個樣本。在分散均勻的資料集當中,我們可以在 的時間內完成查詢,但是對於特殊情況可能會長一些,但是也比我們通過樸素的方法查詢要快得多。

我們很容易發現,KD-Tree一個廣泛的使用場景是用來優化KNN演算法。我們在之前介紹KNN演算法的文章當中曾經提到過,KNN演算法在預測的時候需要遍歷整個資料集,然後計算資料集中每一個樣本與當前樣本的距離,選出最近的K個來,這需要大量的開銷。而使用KD-Tree,我們可以在一次查詢當中直接查詢到K個最近的樣本,因此大大提升KNN演算法的效率。

那麼,這個查詢操作又是怎麼實現的呢?

這個查詢基於遞迴實現,因此對於遞迴不熟悉的小夥伴,可能初看會比較困難,可以先閱讀一下之前關於遞迴的文章。

首先我們先通過遞迴查詢到KD-Tree上的葉子節點,也就是找到樣本所在的子空間。這個查詢應該非常容易,本質上來說我們就是將當前樣本不停地與分割線進行比較,看看是在分割線的左側還是右側。和二叉搜尋樹的元素查詢是一樣的:

    def iter_down(self, node, data):
        # 如果是葉子節點,則返回
        if node.lchild is None and node.rchild is None:
            return node
        # 如果左節點為空,則遞迴右節點
        if node.lchild is None:
            return self.iter_down(node.rchild, data)
        # 同理,遞迴左節點
        if node.rchild is None:
            return self.iter_down(node.rchild, data)
        # 都不為空則和分割線判斷是左還是右
        axis = node.axis
        next_node = node.lchild if data[axis] <= node.boundray else node.rchild
        return self.iter_down(next_node, data)

我們找到了葉子節點,其實代表樣本空間當中的一小塊空間。

我們來實際走一下整個流程,假設我們要查詢3個點。首先,我們會建立一個候選集,用來儲存答案。當我們找到葉子節點之後,這個區域當中只有一個點,我們把它加入候選集。

在上圖當中紫色的x代表我們查詢的樣本,我們查詢到的葉子節點之後,在兩種情況下我們會把當前點加入候選集。第一種情況是候選集還有空餘,也就是還沒有滿K個,這裡的K是我們查詢的數量,也就是3。第二種情況是當前點到樣本的距離小於候選集中最大的一個,那麼我們需要更新候選集。

這個點被我們訪問過之後,我們會打上標記,表示這個點已經訪問過了。這個時候我們需要判斷,整棵樹當中的搜尋是否已經結束,如果當前節點已經是根節點了,說明我們的遍歷結束了,那麼返回候選集,否則說明還沒有,我們需要繼續搜尋。上圖當中我們用綠色表示樣本被放入了候選集當中,黃色表示已經訪問過。

由於我們的搜尋還沒有結束,所以需要繼續搜尋。繼續搜尋需要判斷樣本和當前分割線的距離來判斷和分割線的另一側有沒有可能存在答案。由於葉子節點沒有另一側,所以作罷,我們往上移動一個,跳轉到它的父親節點。

我們計算距離並且檢視候選集,此時候選集未滿,我們加入候選集,標記為已經訪問過。它雖然存在分割線,但是也沒有另一側的節點,所以也跳過。

我們再往上,遍歷到它的父親節點,我們執行同樣的判斷,發現此時候選集還有空餘,於是將它繼續加入答案:

但是當我們判斷到分割線距離的時候,我們發現這一次樣本到分割線的舉例要比之前候選集當中的最大距離要小,所以分割線的另一側很有可能存在答案:

這裡的d1是樣本到分割線的距離,d2是樣本到候選集當中最遠點的距離。由於到分割線更近,所以分割線的另一側很有可能也存在答案,這個時候我們需要搜尋分割線另一側的子樹,一直搜尋到葉子節點。

我們找到了葉子節點,計算距離,發現此時候選集已經滿了,並且它的距離大於候選集當中任何一個答案,所以不能構成新的答案。於是我們只是標記它已經訪問過,並不會加入候選集。同樣,我們繼續往上遍歷,到它的父節點:

比較之後發現,data到它的距離小於候選集當中最大的那個,於是我們更新候選集,去掉距離大於它的答案。然後我們重複上述的過程,直到根節點為止。

由於後面沒有更近的點,所以候選集一直沒有更新,最後上圖當中的三個打了綠標的點就是答案。

我們把上面的流程整理一下,就得到了遞迴函式當中的邏輯,我們用Python寫出來其實已經和程式碼差不多了:

def query(node, data, answers, K):
    # 判斷node是否已經訪問過
    if node.visited:
        # 標記訪問
        node.visited = True
        # 計算data到node中點的距離
        dis = cal_dis(data, node.point)
        # 如果小於答案中最大值則更新答案
        if dis < max(answers):
            answers.update(node.point)
        # 計算data到分割線的距離
        dis = cal_dis(data, node.split)
        # 如果小於最長距離,說明另一側還可能有答案
        if dis < max(answers):
            # 獲取當前節點的兄弟節點
            brother = self.get_brother(node)
            if brother is not None:
                # 往下搜尋到葉子節點,從葉子節點開始尋找
                leaf = self.iter_down(brother, data)
                if leaf is not None:
                    return self.query(leaf, data, answers, K)
        # 如果已經到了根節點了,退出
        if node is root:
            return answers
        # 遞迴父親節點
        return self.query(node.father, data, answers, K)
    else:
        if node is root:
            return answers
        return self.query(node.father, data, answers, K)

最終寫成的程式碼和上面這段並沒有太多的差別,在得到距離之後和答案當中的最大距離進行比較的地方,我們使用了優先佇列。其他地方几乎都是一樣的,我也貼上來給大家感受一下:

def _query_nearest_k(self, node, path, data, topK, K):
    # 我們用set記錄訪問過的路徑,而不是直接在節點上標記
    if node not in path:
        path.add(node)
        # 計算歐氏距離
        dis = KDTree.distance(node.value, data)
        if (len(topK) < K or dis < topK[-1]['distance']):
            topK.append({'node': node, 'distance': dis})
            # 使用優先佇列獲取topK
            topK = heapq.nsmallest(K, topK, key=lambda x: x['distance'])
        axis = node.axis
        # 分割線都是直線,直接計算座標差
        dis = abs(data[axis] - node.boundray)
        if len(topK) < K or dis <= topK[-1]['distance']:
            brother = self.get_brother(node, path)
            if brother is not None:
                next_node = self.iter_down(brother, data)
                if next_node is not None:
                    return self._query_nearest_k(next_node, path, data, topK, K)
        if node == self.root:
            return topK
        return self._query_nearest_k(node.father, path, data, topK, K)
    else:
        if node == self.root:
            return topK
        return self._query_nearest_k(node.father, path, data, topK, K)

這段邏輯大家應該都能看明白,但是有一個疑問是,我們為什麼不在node裡面加一個visited的欄位,而是通過傳入一個set來維護訪問過的節點呢?這個邏輯只看程式碼是很難想清楚的,必須要親手實驗才會理解。如果在node當中加入一個欄位當然也是可以的,如果這樣做的話,在我們執行查詢之後必須得手動再執行一次遞迴,將樹上所有節點的node全部置為false,否則下一次查詢的時候,會有一些節點已經被標記成了True,顯然會影響結果。查詢之後將這些值手動還原會帶來開銷,所以才轉換思路使用set來進行訪問判斷。

這裡的iter_down函式和我們上面貼的查詢葉子節點的函式是一樣的,就是查詢當前子樹的葉子節點。如果我沒記錯的話,這也是我們文章當中第一次出現在遞迴當中呼叫另一個遞迴的情況。對於初學者而言,這在理解上可能會相對困難一些。我個人建議可以親自動手試一試在紙上畫一個kd-tree進行手動模擬試一試,自然就能知道其中的執行邏輯了。這也是一個思考和學習非常好用的方法。

優化

當我們理解了整個kd-tree的建樹和查詢的邏輯之後,我們來考慮一下優化。

這段程式碼看下來初步可以找到兩個可以優化的地方,第一個地方是我們建樹的時候。我們每次遞迴的時候由於要將資料一分為二,我們是使用了排序的方法來實現的,而每次排序都是 的複雜度,這其實是不低的。其實仔細想想,我們沒有必要排序,我們只需要選出根據某個軸排序前n/2個數。也就是說這是一個選擇問題,並不是排序問題,所以可以想到我們可以利用之前講過的快速選擇的方法來優化。使用快速選擇,我們可以在 的時間內完成資料的拆分。

另一個地方是我們在查詢K個鄰近點的時候,我們使用了優先佇列維護的候選集當中的答案,方便我們對答案進行更新。同樣,優先佇列獲取topK也是 的複雜度。這裡也是可以優化的,比較好的思路是使用堆來代替。可以做到 的插入和彈出,相比於heapq的nsmallest方法要效率更高。

總結

到這裡,我們關於KD-tree的原理部分已經差不多講完了,我們有了建樹和查詢功能之後就可以用在KNN演算法上進行優化了。但是我們現在的KD-tree只支援建樹以及查詢,如果我們想要插入或者刪除集合當中的資料應該怎麼辦?難道每次修改都重新建樹嗎?這顯然不行,但是插入和刪除節點都會引起樹結構的變化很有可能導致樹不再平衡,這個時候我們應該怎麼辦呢?

我們先賣個關子,相關的內容將會放到下一篇文章當中,感興趣的同學不要錯過哦。

最後,我把KD-tree完整的程式碼放在了ubuntu.paste上,想要檢視完整原始碼的同學請在公眾號內回覆kd-tree進行檢視。

今天的文章就是這些,如果覺得有所收穫,請順手點個關注或者轉發吧,你們的舉手之勞對我來說很重要。

相關推薦

機器學習——KD-Tree原理

本文始發於個人公眾號:TechFlow,原創不易,求個關注 今天是機器學習的第15篇文章,之前的文章當中講了Kmeans的相關優化,還講了大名鼎鼎的EM演算法。有些小夥伴表示喜歡看這些硬核的,於是今天上點硬菜,我們來看一個機器學習領域經常用到的資料結構——KD-Tree。 從線段樹到KD樹 在講KD樹之前,

機器學習 | GBDT梯度提升樹原理,看完再也不怕面試了

本文始發於個人公眾號:TechFlow,原創不易,求個關注 今天是機器學習專題的第30篇文章,我們今天來聊一個機器學習時代可以說是最厲害的模型——GBDT。 雖然文無第一武無第二,在機器學習領域並沒有什麼最厲害的模型這一說。但在深度學習興起和流行之前,GBDT的確是公認效果最出色的幾個模型之一。雖然現在

機器學習 | GBDT在分類場景中的應用原理與公式推導

本文始發於個人公眾號:**TechFlow**,原創不易,求個關注 今天是**機器學習專題**的第31篇文章,我們一起繼續來聊聊GBDT模型。 在上一篇文章當中,我們學習了GBDT這個模型在迴歸問題當中的原理。GBDT最大的特點就是對於損失函式的降低不是通過調整模型當中已有的引數實現的,若是通過

機器學習筆記之八—— knn-最簡單的機器學習演算法以及KD原理

上一節結束了線性迴歸、邏輯迴歸,今天一節來介紹機器學習中最簡單的演算法:    K近鄰(KNN,全稱K-nearst Neighbor)       概述:判斷一個樣本的label只需要判斷該樣本週圍其他樣本的label。簡言之,朋

機器學習】KNN分類的概念、誤差率及其問題

勿在浮沙築高臺 KNN概念         KNN(K-Nearest Neighbors algorithm)是一種非引數模型演算法。在訓練資料量為N的樣本點中,尋找最近鄰測試資料x的K個樣本,然

機器學習無約束優化問題:梯度下降、牛頓法、擬牛頓法

無約束優化問題是機器學習中最普遍、最簡單的優化問題。 x∗=minxf(x),x∈Rn 1.梯度下降 梯度下降是最簡單的迭代優化演算法,每一次迭代需求解一次梯度方向。函式的負梯度方向代表使函式值減小最快的方向。它的思想是沿著函式負梯度方向移動逐步逼

[機器學習]分類演算法--決策樹演算法

前言 演算法的有趣之處在於解決問題,否則僅僅立足於理論,便毫無樂趣可言; 不過演算法的另一特點就是容易嚇唬人,又是公式又是圖示啥的,如果一個人數學理論知識過硬,靜下心來看,都是可以容易理解的,紙老虎一個,不過這裡的演算法主要指的應用型演算法

機器學習】SMO演算法剖析

CSDN−勿在浮沙築高臺 本文力求簡化SMO的演算法思想,畢竟自己理解有限,無奈還是要拿一堆公式推來推去,但是靜下心看完本篇並隨手推導,你會迎刃而解的。推薦參看SMO原文中的虛擬碼。 1.SMO概念 上一篇部落格已經詳細介紹了SVM原理,為了方便求解,把原

機器學習】線性迴歸、梯度下降、最小二乘的幾何和概率解釋

線性迴歸 即線性擬合,給定N個樣本資料(x1,y1),(x2,y2)....(xN,yN)其中xi為輸入向量,yi表示目標值,即想要預測的值。採用曲線擬合方式,找到最佳的函式曲線來逼近原始資料。通過使得代價函式最小來決定函式引數值。 採用斯坦福大學公開課的

機器學習——經典聚類演算法Kmeans

本文始發於個人公眾號:**TechFlow**,原創不易,求個關注 今天是機器學習專題的第12篇文章,我們一起來看下Kmeans聚類演算法。 在上一篇文章當中我們討論了KNN演算法,KNN演算法非常形象,通過距離公式找到最近的K個鄰居,通過鄰居的結果來推測當前的結果。今天我們要來看的演算法同樣非常直觀,

【深度學習系列】卷積神經網路CNN原理(一)——基本原理

轉自:https://www.cnblogs.com/charlotte77/p/7759802.html 上篇文章我們給出了用paddlepaddle來做手寫數字識別的示例,並對網路結構進行到了調整,提高了識別的精度。有的同學表示不是很理解原理,為什麼傳統的機

B+tree以及mysql的索引原理

最近在學mysq的索引,網上查了很多資料但都沒有很好理解的,現在先講講b+tree 動態查詢樹主要有:二叉查詢樹(Binary Search Tree),平衡二叉查詢樹(Balanced Binary Search Tree),紅黑樹 (Red-Black Tree )

JMX學習

-s agen 技術 操作 三層架構 out javax optional 配置 一、概述:   JMX(Java Management Extensions,即Java管理擴展)是一個為應用程序、設備、系統等植入管理功能的框架。   JMX的核心類是MBean(準確說是接

(轉)Java JVM 工作原理和流程

移植 獲得 代碼 適配 調用 tac 階段 main方法 等待 作為一名Java使用者,掌握JVM的體系結構也是必須的。說起Java,人們首先想到的是Java編程語言,然而事實上,Java是一種技術,它由四方面組成:Java編程語言、Java類文件格式、Java虛擬機和Ja

Java JVM 工作原理和流程

str literal 狀態 應用 流程 href ctu 局部變量 自定義 作為一名Java使用者,掌握JVM的體系結構也是必須的。說起Java,人們首先想到的是Java編程語言,然而事實上,Java是一種技術,它由四方面組成:Java編程語言、Java類文件格式、Jav

純幹貨iptables工作原理以及使用方法

rip -a sports 公網 寫法 內網ip 行處理 外部 是否 簡介 網絡中的防火墻,是一種將內部和外部網絡分開的方法,是一種隔離技術。防火墻在內網與外網通信時進行訪問控制,依據所設置的規則對數據包作出判斷,最大限度地阻止網絡中不法分子破壞企業網絡,從而加強了企業網絡

一文讀懂機器學習大殺器XGBoost原理

結構 近似算法 機器 form con gin fff .cn tran http://blog.itpub.net/31542119/viewspace-2199549/ XGBoost是boosting算法的其中一種。Boosting算法的思想是將許多弱分類器集成在

Python定時任務框架APScheduler學習

情況 類型 container 邏輯 專業 取值 控制 scheduled 執行器 轉載一篇文章,講解了Python定時任務框架APScheduler的使用,原文地址:https://www.cnblogs.com/luxiaojun/p/6567132.html,內容如下

Hadoop偽分佈安裝+MapReduce執行原理+基於MapReduce的KNN演算法實現

本篇部落格將圍繞Hadoop偽分佈安裝+MapReduce執行原理+基於MapReduce的KNN演算法實現這三個方面進行敘述。 (一)Hadoop偽分佈安裝 1、簡述Hadoop的安裝模式中–偽分佈模式與叢集模式的區別與聯絡. Hadoop的安裝方式有三種:本地模式,偽分佈模式

GAIL生成對抗模仿學習《Generative adversarial imitation learning》

前文是一些針對IRL,IL綜述性的解釋,後文是針對《Generative adversarial imitation learning》文章的理解及公式的推導。 通過深度強化學習,我們能夠讓機器人針對一個任務實現從0到1的學習,但是需要我們定義出reward函式,在很多複雜任