1. 程式人生 > >機器學習基礎(四十三)—— kd 樹( k 近鄰法的實現)

機器學習基礎(四十三)—— kd 樹( k 近鄰法的實現)

實現 k 近鄰法時,主要考慮的問題是如何對訓練資料進行快速 k 近鄰搜尋,這點在如下的兩種情況時,顯得尤為必要:

  • (1)特徵空間的維度大
  • (2)訓練資料的容量很大時

k 近鄰法的最簡單的實現是現行掃描(linear scan),這時需計算輸入例項與每一個訓練例項的距離,但訓練集很大時,計算非常耗時,這種方法是不可行的。

為了提高 k 近鄰搜尋的效率,可以考慮使用特殊的結構儲存訓練資料,以減少計算距離的次數。如本文介紹的 kd 樹(kd tree,k-dimensional tree)方法(這裡的 k 表示樣本集的維度,與 k近鄰的 k 無關)。

構造 kd 樹

kd 樹是一種對 k 維空間中的例項進行儲存以便對其進行快速檢索的樹形資料結構。kd樹是一種二叉樹,表示對 k 維空間的一次劃分(partition)。構造 kd 樹相當與不斷地用垂直於座標軸(沿著每一個屬性列,d

=1,2,,k)的超平面(hyperplane)將 k 維空間劃分。

通常依次選擇座標軸對空間劃分,選擇訓練例項點在選定座標軸上的中位數(median)為切分點。這樣得到的 kd 樹是平衡的,平衡的 kd 樹未必就是最優的。

class Node:
    def __init__(self, point):
                # point 表示切分點
        self.left = None
        self.right = None
        self.point = point
def median(l):
    m = len(l)/2
    return
l[m], m def build_kdtree(X, d, depth): k = len(X[0]) X = sorted(X, key=lambda x: x[d]) p, m = median(X) tree = Node(p) print p, depth if m > 0: tree.left = build_kdtree(X[:m], (d+1)%k, depth+1) if (m+1) < len(X): tree.right = build_kdtree(X[m+1:], (d+1
)%k, depth+1) return tree

搜尋 kd 樹

class Node:
    def __init__(self, point):
        self.left = None
        self.right = None
        self.parent = None
        self.point = point
    def set_left(self, left):
        self.left = left
        left.parent = self
    def set_right(self, right):
        self.right = right
        right.parent = self

def search_kdtree(tree, d, target): 
    k = len(tree[0])
    if tree.point[d] < target[d]:
        if tree.right != None:
            return search_kdtree(tree.right, (d+1)%k, target)
    else:
        if tree.left != None:
            return search_kdtree(tree.left, (d+1)%k, target)

    def update_best(t, best):
        if t == None: return 
        t = t.point
        d = euclidean(t, target)
        if  d < best[1]:
            best[1] = d
            best[0] = t

    best = [tree.point, float('inf')]

    while tree.parent != None:
        update_best(tree.parent.left, best)
        update_best(tree.parent.right, best)
        tree = tree.parent
    return best[0]

分析

如果例項點是隨機分佈的,kd搜尋樹的平均時間複雜度為 O(logN)N 表示訓練例項數。kd 樹搜尋更適用於訓練例項數遠大於空間維數時的 k 近鄰搜尋,當空間維數接近訓練例項數(非常畸形,也即接近線性的一顆不平衡的二叉樹)時,它的效率會迅速下降,幾乎接近線性掃描。

References

[1] k 近鄰法