機器學習基礎(四十三)—— kd 樹( k 近鄰法的實現)
阿新 • • 發佈:2018-12-30
實現 k 近鄰法時,主要考慮的問題是如何對訓練資料進行快速 k 近鄰搜尋,這點在如下的兩種情況時,顯得尤為必要:
- (1)特徵空間的維度大
- (2)訓練資料的容量很大時
k 近鄰法的最簡單的實現是現行掃描(linear scan),這時需計算輸入例項與每一個訓練例項的距離,但訓練集很大時,計算非常耗時,這種方法是不可行的。
為了提高 k 近鄰搜尋的效率,可以考慮使用特殊的結構儲存訓練資料,以減少計算距離的次數。如本文介紹的 kd 樹(kd tree,k-dimensional tree)方法(這裡的 k 表示樣本集的維度,與 k近鄰的 k 無關)。
構造 kd 樹
kd 樹是一種對 k 維空間中的例項進行儲存以便對其進行快速檢索的樹形資料結構。kd樹是一種二叉樹,表示對 k 維空間的一次劃分(partition)。構造 kd 樹相當與不斷地用垂直於座標軸(沿著每一個屬性列, =1,2,…,k
通常依次選擇座標軸對空間劃分,選擇訓練例項點在選定座標軸上的中位數(median)為切分點。這樣得到的 kd 樹是平衡的,平衡的 kd 樹未必就是最優的。
class Node:
def __init__(self, point):
# point 表示切分點
self.left = None
self.right = None
self.point = point
def median(l):
m = len(l)/2
return l[m], m
def build_kdtree(X, d, depth):
k = len(X[0])
X = sorted(X, key=lambda x: x[d])
p, m = median(X)
tree = Node(p)
print p, depth
if m > 0:
tree.left = build_kdtree(X[:m], (d+1)%k, depth+1)
if (m+1) < len(X):
tree.right = build_kdtree(X[m+1:], (d+1 )%k, depth+1)
return tree
搜尋 kd 樹
class Node:
def __init__(self, point):
self.left = None
self.right = None
self.parent = None
self.point = point
def set_left(self, left):
self.left = left
left.parent = self
def set_right(self, right):
self.right = right
right.parent = self
def search_kdtree(tree, d, target):
k = len(tree[0])
if tree.point[d] < target[d]:
if tree.right != None:
return search_kdtree(tree.right, (d+1)%k, target)
else:
if tree.left != None:
return search_kdtree(tree.left, (d+1)%k, target)
def update_best(t, best):
if t == None: return
t = t.point
d = euclidean(t, target)
if d < best[1]:
best[1] = d
best[0] = t
best = [tree.point, float('inf')]
while tree.parent != None:
update_best(tree.parent.left, best)
update_best(tree.parent.right, best)
tree = tree.parent
return best[0]
分析
如果例項點是隨機分佈的,kd搜尋樹的平均時間複雜度為
References
[1] k 近鄰法