1. 程式人生 > >深入淺出聚類演算法之k-means演算法

深入淺出聚類演算法之k-means演算法

k-means是一個十分簡單的聚類演算法,它的思路非常簡明清晰,所以經常拿來當做教學。下面就來講述一下這個模型的細節操作。

內容

  • 模型原理
  • 模型收斂過程
  • 模型聚類個數
  • 模型侷限

1. 模型原理
將某一些資料分為不同的類別,在相同的類別中資料之間的距離應該都很近,也就是說離得越近的資料應該越相似,再進一步說明,資料之間的相似度與它們之間的歐式距離成反比。這就是k-means模型的假設。
有了這個假設,我們對將資料分為不同的類別的演算法就更明確了,儘可能將離得近的資料劃分為一個類別。不妨假設需要將資料{xi}聚為k類,經過聚類之後每個資料所屬的類別為{ti},而這k個聚類的中心為{μi}。於是定義如下的損失函式:
這裡寫圖片描述


k-means模型的目的是找尋最佳的{ti},使損失函式最小,之後就可以對聚類中心{μi}直接計算了。由此可見,它既是聚類的最終結果,也是需要估算的模型引數。


2. 模型收斂過程
在k-means的損失函式中存在兩個未知的引數:一個是每個資料所屬的類別{ti};一個是每個聚類的中心{μi}。這兩個未知的引數是相互依存的:如果知道每個資料的所屬類別,那麼類別的所有資料的平均值就是這個類別的中心;如果知道每個類別的中心,那麼就是計算資料與中心的距離,再根據距離的大小可以推斷出資料屬於哪一個類別。
根據這個思路,我們可以使用EM演算法(最大期望演算法)來估計模型的引數。具體操作如下:
1. 首先隨機生成k個聚類中心點
2. 根據聚類中心點,將資料分為k類。分類的原則是資料離哪個中心點近就將它分為哪一類別。
3. 再根據分好的類別的資料,重新計算聚類的類別中心點。
4. 不斷的重複2和3步,直到中心點不再變化。如下圖所示:
這裡寫圖片描述


3. 模型聚類個數
對於非監督學習,訓練資料是沒有標註變數的。那麼除了極少數的情況,我們都是無從知道資料應該被分為幾類。k-means演算法首先是隨機產生幾個聚類中心點,如果聚類中心點多了,會造成過擬合;如果聚類中心點少了,會造成欠擬合,所以聚類中心點是很關鍵的,在這裡使用誤差平方的變化和來評價模型預測結果好不好。當聚類個數小於真實值時,誤差平方和會下降的很快;當聚類個數超過真實值時,誤差平方和雖然會繼續下降,但是下降的速度會緩減,而這個轉折點就是最佳的聚類個數了。
這裡寫圖片描述

這裡寫圖片描述


4. 模型侷限
k-means是非常簡單的模型,但是它也有兩個明顯的缺陷,或者說它有兩種運用場景不能使用,第一是非均質的資料

,因為,模型使用歐氏距離衡量資料間的相似度,因此它要求資料在各個維度上都是均質。第二是不同類別內部方差不相同。模型假設不同類別的內部方差是大致相等的。
這裡寫圖片描述


下面使用鳶尾花資料集進行實戰。
引入包:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import seaborn as sns
%matplotlib inline

觀察資料:

data = sns.load_dataset("iris")
data.head()
sns.pairplot(data, hue='species')

這裡寫圖片描述
觀察兩兩變數中聚類個數:

def trainModel(data, clusterNum):
    """
    使用KMeans對資料進行聚類
    """
    # max_iter表示EM演算法迭代次數,n_init表示K-means演算法迭代次數,algorithm="full"表示使用EM演算法。
    model = KMeans(n_clusters=clusterNum, max_iter=100, n_init=10, algorithm="full")
    model.fit(data)
    return model


def computeSSE(model, data):
    """
    計算聚類結果的誤差平方和
    """
    wdist = model.transform(data).min(axis=1)
    sse = np.sum(wdist ** 2)
    return sse

if __name__ == "__main__":
    col = [['petal_width', 'sepal_length'], ['petal_width', 'petal_length'], ['petal_width', 'sepal_width'], ['sepal_length', 'petal_length'],
['sepal_length', 'sepal_width'], ['petal_length', 'sepal_width']]

    for i in range(6):      
        fig = plt.figure(figsize=(8, 8), dpi=80)
        ax = fig.add_subplot(3, 2, i+1)
        sse = []
        for j in range(2, 6):
            model = trainModel(data[col[i]], j)
            sse.append(computeSSE(model, data[col[i]]))
        ax.plot(range(2,6), sse, 'k--', marker="o",
        markerfacecolor="r", markeredgecolor="k")
        ax.set_xticks([1,2,3,4,5,6])
        title = "clusterNum of %s and %s" % (col[i][0], col[i][1])
        ax.title.set_text(title)
        plt.show()

這裡寫圖片描述
通過這個圖,我們基本上可以判斷出應該分為三類,這也與實際情況是相同的。我們選擇一組進行視覺化聚類結果。

petal_data = data[['petal_width', 'petal_length']]
model = trainModel(petal_data, 3)
fig = plt.figure(figsize=(6,6), dpi=80)
ax = fig.add_subplot(1,1,1)
colors = ["r", "b", "g"]
ax.scatter(petal_data.petal_width, petal_data.petal_length, c=[colors[i] for i in model.labels_],
    marker="o", alpha=0.8)
ax.scatter(model.cluster_centers_[:, 0], model.cluster_centers_[:, 1], marker="*", c=colors, edgecolors="white",
    s=700., linewidths=2)
yLen = petal_data.petal_length.max() - petal_data.petal_length.min()
xLen = petal_data.petal_width.max() - petal_data.petal_width.min()
lens = max(yLen+1, xLen+1) / 2.
ax.set_xlim(petal_data.petal_width.mean()-lens, petal_data.petal_width.mean()+lens)
ax.set_ylim(petal_data.petal_length.mean()-lens, petal_data.petal_length.mean()+lens)
ax.set_ylabel("petal_length")
ax.set_xlabel("petal_width")

這裡寫圖片描述
這個效果是很好的,與實際的情況一致!