
Zhou Zhihua, "Machine Learning", Ch9. Clustering: a Python implementation of the k-means algorithm

Theory

k-means is a widely used clustering method. Its goal is to minimize

\sum\limits_{i=1}^k\sum_{j=1}^{m_i}||x_{ij}-\mu_i||_2^2

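where \mu_i is the center of the i-th cluster, x_{ij} is the j-th sample assigned to cluster i, and m_i is the number of samples in cluster i. Minimizing this expression directly is hard (finding the global optimum is NP-hard), so the k-means algorithm uses an iterative approximation instead.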
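To make the objective concrete, here is a minimal sketch that evaluates it for a given assignment, using the squared Euclidean distance (the standard form of the k-means objective). The function name `kmeans_objective` and the toy arrays are assumptions made up for illustration; they are not part of the original post or dataset.

```python
import numpy as np

def kmeans_objective(x, labels, centers):
    """Sum of squared Euclidean distances from each sample to its assigned center.

    x: (m, d) samples; labels: (m,) integer cluster ids; centers: (k, d).
    """
    diff = x - centers[labels]          # (m, d) offset of each sample from its own center
    return float(np.sum(diff ** 2))

# Toy data, invented for illustration only
x = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
labels = np.array([0, 0, 1, 1])
centers = np.array([[0.0, 1.0], [10.0, 1.0]])
print(kmeans_objective(x, labels, centers))   # prints 4.0
```

Every sample here lies at distance 1 from its center, so the four squared distances add up to 4.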

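In short, the k-means algorithm repeats two steps in a loop:

1. Compute the distance from each sample to every cluster center, and assign the sample the label of the nearest cluster;

2. Recompute each cluster center as the mean of the samples that carry its label.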
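The two steps can be sketched as a single vectorized NumPy function. This is an illustrative sketch, not the implementation from the post: the name `kmeans_step`, the toy data, and the choice to leave an empty cluster's center unchanged are all assumptions.

```python
import numpy as np

def kmeans_step(x, centers):
    """One k-means iteration: assign labels, then recompute centers."""
    # Step 1: squared distance from every sample to every center, shape (m, k)
    dist2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = dist2.argmin(axis=1)
    # Step 2: each center becomes the mean of the samples assigned to it
    new_centers = centers.copy()
    for j in range(centers.shape[0]):
        members = x[labels == j]
        if len(members) > 0:            # assumption: an empty cluster keeps its old center
            new_centers[j] = members.mean(axis=0)
    return labels, new_centers

# Toy data, invented for illustration only
x = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
centers = np.array([[1.0, 1.0], [8.0, 1.0]])
labels, centers = kmeans_step(x, centers)
```

Broadcasting `x[:, None, :]` against `centers[None, :, :]` builds all sample-to-center differences at once, which replaces the explicit double loop over samples and clusters.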

The k-means algorithm takes two input parameters that the user must specify: the number of clusters and the number of iterations.
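A fixed iteration count is the simplest stopping rule, but the loop can also stop as soon as the labels no longer change. The sketch below treats the iteration count as an upper bound. The names `run_kmeans` and `max_iter`, the early-stopping rule, and the toy data are assumptions for illustration; the initial centers are taken from evenly spaced sample indices.

```python
import numpy as np

def run_kmeans(x, k, max_iter):
    """Run k-means for at most max_iter iterations; stop early once labels are stable."""
    m = x.shape[0]
    # Assumed initialization: k samples at evenly spaced indices
    centers = x[[m * i // k for i in range(k)]].astype(float)
    labels = np.full(m, -1)
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        dist2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        new_labels = dist2.argmin(axis=1)
        if np.array_equal(new_labels, labels):
            break                       # assignments unchanged: converged
        labels = new_labels
        for j in range(k):
            members = x[labels == j]
            if len(members) > 0:        # an empty cluster keeps its old center
                centers[j] = members.mean(axis=0)
    return labels, centers, n_iter

# Toy data, invented for illustration only
x = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
labels, centers, n_iter = run_kmeans(x, k=2, max_iter=10)
print(labels, n_iter)
```

On this toy input the labels settle after the first pass, so the second pass detects no change and the loop exits well before `max_iter`.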

Code

# -*- coding: utf-8 -*-
"""
k-means algorithm
From 'Machine Learning, Zhihua Zhou' Ch9
Model: k-means clustering algorithm
Dataset: P202 watermelon_4.0 (watermelon_4.0.npy)

@author: weiyx15
"""

import numpy as np
import matplotlib.pyplot as plt

class kMeans:
    def load_data(self, filename):
        self.x = np.load(filename)
        self.m = self.x.shape[0]            # number of samples
        self.d = self.x.shape[1]            # number of features
        
    def __init__(self, kk, repeat):
        self.load_data('watermelon_4.0.npy')
        self.k = kk             # number of clusters
        self.rep = repeat       # number of iterations
        self.P = np.zeros((self.k, self.d)) # cluster centers, one row per cluster
        for i in range(self.k):             # initialize centers from evenly spaced samples
            # integer arithmetic avoids float rounding in the index
            self.P[i, :] = self.x[self.m * i // self.k, :]
        self.L = np.zeros((self.m,), dtype=int) # cluster label of each sample
        
    def getLabel(self, xi): # INPUT a sample, OUTPUT its label
        dmin = np.inf
        jmin = 0
        for j in range(self.k):
            dij = np.linalg.norm(xi - self.P[j, :])
            if dij < dmin:
                dmin = dij
                jmin = j
        return jmin
            
    
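    def train(self):
        for r in range(self.rep):
            # assignment step: label each sample with its nearest center
            for i in range(self.m):
                self.L[i] = self.getLabel(self.x[i, :])
            # update step: move each center to the mean of its members
            for i in range(self.k):
                members = self.x[self.L == i, :]
                if members.shape[0] > 0:    # empty cluster: keep the old center, avoid 0/0
                    self.P[i, :] = members.mean(axis=0)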
            
    def plot_data(self):
        # plots the first two feature dimensions only
        color = ['r', 'b', 'y']
        plt.figure()
        for i in range(self.k):
            c = color[i % len(color)]   # cycle colors so k > 3 does not raise IndexError
            plt.plot(self.x[self.L == i, 0], self.x[self.L == i, 1], c + '.')
            plt.plot(self.P[i, 0], self.P[i, 1], c + '*')   # cluster center
        plt.show()
    
if __name__ == '__main__':
    km = kMeans(3, 10)
    km.train()
    km.plot_data()

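Results

The figure below shows the result of running k-means on watermelon dataset 4.0 with 3 clusters and 10 iterations; "*" marks the cluster centers.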