聚類演算法之Mean Shift
在 K-Means演算法 中,最終的聚類效果受初始的聚類中心的影響, K-Means++ 演算法的提出,為選擇較好的初始聚類中心提供了依據,但是演算法中, 聚類的類別個數k 仍需事先制定,對於類別個數事先未知的資料集,K-Means和K-Means++將很難對其精確求解,對此,有一些改進的演算法被提出來處理聚類個數k未知的情形。Mean Shift演算法,又被稱為均值漂移演算法,與K-Means演算法一樣,都是基於聚類中心的聚類演算法,不同的是,Mean Shift演算法不需要事先制定類別個數k。
Mean Shift的概念最早是由Fukunage在1975年提出的,在後來由Yizong Cheng對其進行擴充,主要提出了兩點的改進:定義了核函式,增加了權重係數。核函式的定義使得偏移值對偏移向量的貢獻隨之樣本與被偏移點的距離的不同而不同。權重係數使得不同樣本的權重不同。
Mean Shift演算法在很多領域都有成功應用,例如影象平滑、影象分割、物體跟蹤等,這些屬於人工智慧裡面模式識別或計算機視覺的部分;另外也包括常規的聚類應用。
-
影象平滑:影象最大質量下的畫素壓縮;
-
影象分割:跟影象平滑類似的應用,但最終是將可以平滑的影象進行分離已達到前後景或固定物理分割的目的;
-
目標跟蹤:例如針對監控視訊中某個人物的動態跟蹤;
-
常規聚類,如使用者聚類等。
Mean Shift演算法理論
Mean Shift向量
對於給定的維空間 中的n個樣本點
,則對於x點,其Mean Shift向量的基本形式為:
其中, 指的是一個半徑為h的高維球區域,如上圖中的圓形區域。
的定義為:
裡面所有點與圓心為起點形成的向量相加的結果就是Mean shift向量。下圖黃色箭頭就是(Mean Shift向量)。
對於Mean Shift演算法,是一個迭代的步驟,即先算出當前點的偏移均值,將該點移動到此偏移均值,然後以此為新的起始點,繼續移動,直到滿足最終的條件。
Mean-Shift 聚類就是對於集合中的每一個元素,對它執行下面的操作:把該元素移動到它鄰域中所有元素的特徵值的均值的位置,不斷重複直到收斂。準確的說,不是真正移動元素,而是把該元素與它的收斂位置的元素標記為同一類。
如上的均值漂移向量的求解方法存在一個問題,即在 的區域內,每一個樣本點x對樣本X的共享是一樣的。而實際中,每一個樣本點x對樣本X的貢獻是不一樣的,這樣的共享可以通過核函式進行度量。
核函式
在Mean Shift演算法中引入核函式的目的是使得隨著樣本與被偏移點的距離不同,其偏移量對均值偏移向量的貢獻也不同。核函式是機器學習中常用的一種方式。核函式的定義如下所示:
X 表示一個d維的歐式空間,x 是該空間中的一個點 ,其中,x的模
,R表示實數域,如果一個函式K:X→R存在一個剖面函式k:[0,∞]→R,即
並且滿足:
-
k是非負的
-
k是非增的
-
k是分段連續的
那麼,函式K(x)就稱為核函式。
核函式 有很多,下圖中表示的每一個曲線都為一個核函式。
常用的核函式有高斯核函式。高斯核函式如下所示:
其中,h稱為頻寬(bandwidth),不同頻寬的核函式如下圖所示:
從高斯函式的影象可以看出,當頻寬h一定時,樣本點之間的距離越近,其核函式的值越大,當樣本點之間的距離相等時,隨著高斯函式的頻寬h的增加,核函式的值在減小。
高斯核函式的Python實現:
# -*- coding:utf-8 -*- import numpy as np import math def gaussian_kernel(distance, bandwidth): ''' 高斯核函式 :param distance: 歐氏距離計算函式 :param bandwidth: 核函式的頻寬 :return: 高斯函式值 ''' m = np.shape(distance)[0]# 樣本個數 right = np.mat(np.zeros((m, 1)))# m * 1 矩陣 for i in range(m): right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth) right[i, 0] = np.exp(right[i, 0]) left = 1 / (bandwidth * math.sqrt(2 * math.pi)) gaussian_val = left * right return gaussian_val
引入核函式的Mean Shift向量
假設在半徑為h的範圍 範圍內,為了使得每一個樣本點x對於樣本X的共享不一樣,向基本的Mean Shift向量形式中增加核函式,得到如下改進的Mean Shift向量形式:
其中, 為核函式。通常,可以取
為整個資料集範圍。
計算時考慮距離的影響,同時也可以認為在所有的樣本點X中,重要性並不一樣,因此對每個樣本還引入一個權重係數。如此以來就可以把Mean Shift形式擴充套件為:
其中, 是一個賦給取樣點的權重。
聚類動畫演示
Mean Shift的程式碼實現
演算法的 Python 實現
import numpy as np import math MIN_DISTANCE = 0.00001# 最小誤差 def euclidean_dist(pointA, pointB): # 計算pointA和pointB之間的歐式距離 total = (pointA - pointB) * (pointA - pointB).T return math.sqrt(total) def gaussian_kernel(distance, bandwidth): ''' 高斯核函式 :param distance: 歐氏距離計算函式 :param bandwidth: 核函式的頻寬 :return: 高斯函式值 ''' m = np.shape(distance)[0]# 樣本個數 right = np.mat(np.zeros((m, 1))) for i in range(m): right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth) right[i, 0] = np.exp(right[i, 0]) left = 1 / (bandwidth * math.sqrt(2 * math.pi)) gaussian_val = left * right return gaussian_val def shift_point(point, points, kernel_bandwidth): '''計算均值漂移點 :param point: 需要計算的點 :param points: 所有的樣本點 :param kernel_bandwidth: 核函式的頻寬 :return: point_shifted:漂移後的點 ''' points = np.mat(points) m = np.shape(points)[0]# 樣本個數 # 計算距離 point_distances = np.mat(np.zeros((m, 1))) for i in range(m): point_distances[i, 0] = euclidean_dist(point, points[i]) # 計算高斯核 point_weights = gaussian_kernel(point_distances, kernel_bandwidth) # 計算分母 all = 0.0 for i in range(m): all += point_weights[i, 0] # 均值偏移 point_shifted = point_weights.T * points / all return point_shifted def group_points(mean_shift_points): '''計算所屬的類別 :param mean_shift_points:漂移向量 :return: group_assignment:所屬類別 ''' group_assignment = [] m, n = np.shape(mean_shift_points) index = 0 index_dict = {} for i in range(m): item = [] for j in range(n): item.append(str(("%5.2f" % mean_shift_points[i, j]))) item_1 = "_".join(item) if item_1 not in index_dict: index_dict[item_1] = index index += 1 for i in range(m): item = [] for j in range(n): item.append(str(("%5.2f" % mean_shift_points[i, j]))) item_1 = "_".join(item) group_assignment.append(index_dict[item_1]) return group_assignment def train_mean_shift(points, kernel_bandwidth=2): '''訓練Mean Shift模型 :param points: 特徵資料 :param kernel_bandwidth: 核函式頻寬 :return: points:特徵點 mean_shift_points:均值漂移點 group:類別 ''' mean_shift_points = np.mat(points) max_min_dist = 1 iteration = 0 m = np.shape(mean_shift_points)[0]# 樣本的個數 need_shift = [True] * m# 標記是否需要漂移 # 計算均值漂移向量 while max_min_dist > MIN_DISTANCE: max_min_dist = 0 iteration += 1 print("iteration : " + str(iteration)) for i in range(0, m): # 判斷每一個樣本點是否需要計算偏置均值 if not need_shift[i]: continue p_new = mean_shift_points[i] p_new_start = p_new p_new = shift_point(p_new, points, kernel_bandwidth)# 對樣本點進行偏移 dist = euclidean_dist(p_new, p_new_start)# 計算該點與漂移後的點之間的距離 if dist > max_min_dist:# 記錄是有點的最大距離 max_min_dist = dist if dist < MIN_DISTANCE:# 不需要移動 need_shift[i] = False mean_shift_points[i] = p_new # 計算最終的group group = group_points(mean_shift_points)# 計算所屬的類別 return np.mat(points), mean_shift_points, group
以上程式碼實現了基本的流程,但是執行效率很慢,正式使用時建議使用scikit-learn庫中的MeanShift。
scikit-learn MeanShift 演示
import numpy as np from sklearn.cluster import MeanShift, estimate_bandwidth data = [] f = open("k_means_sample_data.txt", 'r') for line in f: data.append([float(line.split(',')[0]), float(line.split(',')[1])]) data = np.array(data) # 通過下列程式碼可自動檢測bandwidth值 # 從data中隨機選取1000個樣本,計算每一對樣本的距離,然後選取這些距離的0.2分位數作為返回值,當n_samples很大時,這個函式的計算量是很大的。 bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=1000) print(bandwidth) # bin_seeding設定為True就不會把所有的點初始化為核心位置,從而加速演算法 ms = MeanShift(bandwidth=bandwidth, bin_seeding=True) ms.fit(data) labels = ms.labels_ cluster_centers = ms.cluster_centers_ # 計算類別個數 labels_unique = np.unique(labels) n_clusters = len(labels_unique) print("number of estimated clusters : %d" % n_clusters) # 畫圖 import matplotlib.pyplot as plt from itertools import cycle plt.figure(1) plt.clf()# 清楚上面的舊圖形 # cycle把一個序列無限重複下去 colors = cycle('bgrcmyk') for k, color in zip(range(n_clusters), colors): # current_member表示標籤為k的記為true 反之false current_member = labels == k cluster_center = cluster_centers[k] # 畫點 plt.plot(data[current_member, 0], data[current_member, 1], color + '.') #畫圈 plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=color,#圈內顏色 markeredgecolor='k',#圈邊顏色 markersize=14)#圈大小 plt.title('Estimated number of clusters: %d' % n_clusters) plt.show()
執行效果:
scikit-learn MeanShift原始碼分析
原始碼地址: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/cluster/mean_shift_.py
def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, max_iter=300, n_jobs=1): """Perform mean shift clustering of data using a flat kernel. Read more in the :ref:`User Guide <mean_shift>`. Parameters ---------- X : array-like, shape=[n_samples, n_features] Input data. bandwidth : float, optional Kernel bandwidth. If bandwidth is not given, it is determined using a heuristic based on the median of all pairwise distances. This will take quadratic time in the number of samples. The sklearn.cluster.estimate_bandwidth function can be used to do this more efficiently. seeds : array-like, shape=[n_seeds, n_features] or None Point used as initial kernel locations. If None and bin_seeding=False, each data point is used as a seed. If None and bin_seeding=True, see bin_seeding. bin_seeding : boolean, default=False If true, initial kernel locations are not locations of all points, but rather the location of the discretized version of points, where points are binned onto a grid whose coarseness corresponds to the bandwidth. Setting this option to True will speed up the algorithm because fewer seeds will be initialized. Ignored if seeds argument is not None. min_bin_freq : int, default=1 To speed up the algorithm, accept only those bins with at least min_bin_freq points as seeds. cluster_all : boolean, default True If true, then all points are clustered, even those orphans that are not within any kernel. Orphans are assigned to the nearest kernel. If false, then orphans are given cluster label -1. max_iter : int, default 300 Maximum number of iterations, per seed point before the clustering operation terminates (for that seed point), if has not converged yet. n_jobs : int The number of jobs to use for the computation. This works by computing each of the n_init runs in parallel. If -1 all CPUs are used. If 1 is given, no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used. .. versionadded:: 0.17 Parallel Execution using *n_jobs*. Returns ------- cluster_centers : array, shape=[n_clusters, n_features] Coordinates of cluster centers. labels : array, shape=[n_samples] Cluster labels for each point. Notes ----- See examples/cluster/plot_mean_shift.py for an example. """ #沒有定義bandwidth執行函式estimate_bandwidth估計頻寬 if bandwidth is None: bandwidth = estimate_bandwidth(X, n_jobs=n_jobs) #頻寬小於0就報錯 elif bandwidth <= 0: raise ValueError("bandwidth needs to be greater than zero or None,\ got %f" % bandwidth) #如果沒有設定種子 if seeds is None: #通過get_bin_seeds選取種子 #min_bin_freq指定最少的種子數目 if bin_seeding: seeds = get_bin_seeds(X, bandwidth, min_bin_freq) #把所有點設為種子 else: seeds = X #根據shape得到樣本數量和特徵數量 n_samples, n_features = X.shape #中心強度字典 鍵為點 值為強度 center_intensity_dict = {} #近鄰搜尋 fit的返回值為 #radius意思是半徑 表示引數空間的範圍 #用作於radius_neighbors 可以理解為在半徑範圍內找鄰居 nbrs = NearestNeighbors(radius=bandwidth, n_jobs=n_jobs).fit(X) #並行地在所有種子上執行迭代 #all_res為所有種子的迭代完的中心以及周圍的鄰居數 # execute iterations on all seeds in parallel all_res = Parallel(n_jobs=n_jobs)( delayed(_mean_shift_single_seed) (seed, X, nbrs, max_iter) for seed in seeds) #遍歷所有結果 # copy results in a dictionary for i in range(len(seeds)): #只有這個點的周圍沒有鄰居才會出現None的情況 if all_res[i] is not None: #一箇中心點對應一個強度(周圍鄰居個數) center_intensity_dict[all_res[i][0]] = all_res[i][1] #要是一個符合要求的點都沒有,就說明bandwidth設定得太小了 if not center_intensity_dict: # nothing near seeds raise ValueError("No point was within bandwidth=%f of any seed." " Try a different seeding strategy \ or increase the bandwidth." % bandwidth) # POST PROCESSING: remove near duplicate points # If the distance between two kernels is less than the bandwidth, # then we have to remove one because it is a duplicate. Remove the # one with fewer points. #按照強度來排序 #dict.items()返回值形式為[(key1,value1),(key2,value2)...] #reverse為True表示由大到小 #key的lambda表示式用來指定用作比較的部分為value sorted_by_intensity = sorted(center_intensity_dict.items(), key=lambda tup: tup[1], reverse=True) #單獨把排好序的點分出來 sorted_centers = np.array([tup[0] for tup in sorted_by_intensity]) #返回長度和點數量相等的bool型別array unique = np.ones(len(sorted_centers), dtype=np.bool) #在這些點裡再來一次找鄰居 nbrs = NearestNeighbors(radius=bandwidth, n_jobs=n_jobs).fit(sorted_centers) #enumerate返回的是index,value #還是類似於之前的找鄰居 不過這次是為了剔除相近的點 就是去除重複的中心 #因為是按強度由大到小排好序的 所以優先將靠前的當作確定的中心 for i, center in enumerate(sorted_centers): if unique[i]: neighbor_idxs = nbrs.radius_neighbors([center], return_distance=False)[0] #中心的鄰居不能作為候選 unique[neighbor_idxs] = 0 #因為這個範圍內肯定包含自己,所以要單獨標為1 unique[i] = 1# leave the current point as unique #把篩選過後的中心拿出來 就是最終的聚類中心 cluster_centers = sorted_centers[unique] #分配標籤:最近的類就是這個點的類 # ASSIGN LABELS: a point belongs to the cluster that it is closest to #把中心放進去 用kneighbors來找鄰居 #n_neighbors標為1 使找到的鄰居數為1 也就成了標籤 nbrs = NearestNeighbors(n_neighbors=1, n_jobs=n_jobs).fit(cluster_centers) #labels用來存放標籤 labels = np.zeros(n_samples, dtype=np.int) #所有點帶進去求 distances, idxs = nbrs.kneighbors(X) #cluster_all為True表示所有的點都會被聚類 if cluster_all: #flatten可以簡單理解如下 #>>> np.array([[[[1,2]],[[3,4]],[[5,6]]]]).flatten() #array([1, 2, 3, 4, 5, 6]) labels = idxs.flatten() #為False就把距離大於bandwidth的點類別標為-1 else: #先全標-1 labels.fill(-1) #距離小於bandwidth的標False bool_selector = distances.flatten() <= bandwidth #標True的才能參與聚類 labels[bool_selector] = idxs.flatten()[bool_selector] #返回的結果為聚類中心和每個樣本的標籤 return cluster_centers, labels
# separate function for each seed's iterative loop def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): #對於每個種子,梯度上升,直到收斂或者到達max_iter次迭代次數 # For each seed, climb gradient until convergence or max_iter bandwidth = nbrs.get_params()['radius'] #表示收斂時的閾值 stop_thresh = 1e-3 * bandwidth# when mean has converged #記錄完成的迭代次數 completed_iterations = 0 while True: #radius_neighbors尋找my_mean周圍的鄰居 #i_nbrs是符合要求的鄰居的下標 # Find mean of points within bandwidth i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth, return_distance=False)[0] #根據下標找點 points_within = X[i_nbrs] #找不到點就跳出迭代 if len(points_within) == 0: break# Depending on seeding strategy this condition may occur #儲存舊的均值 my_old_mean = my_mean# save the old mean #移動均值,這就是mean-shift名字的由來,每一步的迭代就是計算新的均值點 my_mean = np.mean(points_within, axis=0) #用歐幾里得範數與閾值進行比較判斷收斂 或者 #判斷迭代次數達到上限 # If converged or at max_iter, adds the cluster if (extmath.norm(my_mean - my_old_mean) < stop_thresh or completed_iterations == max_iter): #返回收斂時的均值中心和周圍鄰居個數 #tuple表示轉換成元組 因為之後的center_intensity_dict鍵不能為列表 return tuple(my_mean), len(points_within) #迭代次數增加 completed_iterations += 1
參考資料: