1. 程式人生 > >機器學習--K-means演算法

機器學習--K-means演算法

概述

聚類(K-mean)是一種典型的無監督學習。

採用距離作為相似性的評價指標,即認為兩個物件的距離越近,其相似度就越大。

該演算法認為類簇是由距離靠近的物件組成的,因此把得到緊湊且獨立的簇作為最終目標。

核心思想

通過迭代尋找k個類簇的一種劃分方案,使得用這k個類簇的均值來代表相應各類樣本時所得的總體誤差最小。

k個聚類具有以下特點:各聚類本身儘可能的緊湊,而各聚類之間儘可能的分開。

 k-means演算法的基礎是最小誤差平方和準則,

其代價函式是:

    

       式中,μc(i)表示第i個聚類的均值。

各類簇內的樣本越相似,其與該類均值間的誤差平方越小,對所有類所得到的誤差平方求和,即可驗證分為k類時,各聚類是否是最優的。

上式的代價函式無法用解析的方法最小化,只能有迭代的方法。

實踐

第一步,為了測試,使用指令碼生成1000個數據的資料集。

import numpy as np

#x = np.random.uniform(-6,6,2)
#print(str(x[0])+'\t'+str(x[1]))

with open('data.txt', 'w') as f:  # 以寫的方式開啟檔案
    for i in range(250):
        x1 = 2*np.random.randn(2)
        x2 = 2 * np.random.randn(2)
        x3 = 2 * np.random.randn(2)
        x4 = 2 * np.random.randn(2)
        strr1 = str(x1[0]+6)+'\t'+str(x1[1]+6)+'\n'
        strr2 = str(x2[0] + 6) + '\t' + str(x2[1] - 6) + '\n'
        strr3 = str(x3[0] - 6) + '\t' + str(x3[1] + 6) + '\n'
        strr4 = str(x4[0] - 6) + '\t' + str(x4[1] - 6) + '\n'
        strr = strr1+strr2+strr3+strr4
        f.write(strr)

檢視data.txt裡面的資料:

第二步K-means演算法:(解釋都在註釋中)

from numpy import *
import numpy as np
import matplotlib.pyplot as plt
import time
from threading import Thread
#載入資料

plt.ion() #開啟interactive mode

def loadDataSet(fileName):#解析檔案,按tab分割字元,得到一個浮點數字型別的矩陣
    dataMat = []#檔案的最後一個欄位是類別標籤
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))#將每個元素轉成float型別
        dataMat.append(fltLine)
    dataMat = np.array(dataMat)
    return dataMat

def distEclud(vecA,vecB):
    return sqrt(sum(power(vecA - vecB,2)))#求兩個向量之間的距離

#構建聚簇中心,取k個(此例中為4)隨機質心
def randCent(dataSet,k):
    n = shape(dataSet)[1]
    centroids = mat(zeros((k,n)))#每個質心有n個座標值,總共k個質心
    for j in range(n):
        minJ = min(dataSet[:,j])
        maxJ = max(dataSet[:,j])
        rangeJ = float(maxJ-minJ)
        centroids[:,j] = minJ + rangeJ * random.rand(k,1)
    return centroids

#k-means聚類演算法
def kMeans(dataSet,k,distMeans = distEclud,createCent = randCent):
    m = shape(dataSet)[0]#獲取總資料量
    clusterAssment = mat(zeros((m,2)))#用於存放該樣本屬於哪類及質心距離
    #clusterAssment第一列存放該資料所屬的中心點,第二列是該資料到中心點的距離
    centroids = createCent(dataSet,k)#建立k箇中心點
    clusterChanged = True#用來判斷聚類是否收斂
    while clusterChanged:
        clusterChanged = False
        for i in range(m):#把每個資料點劃分到離它最近的中心點
            minDist = inf;minIndex = -1;#inf為無窮,minIndex為質心的代號
            for j in range(k):#分別計算各個點離k個質心的距離
                distJI = distMeans(centroids[j,:],dataSet[i,:])
                if distJI < minDist:#找到離這個點最近的質心
                    minDist = distJI;minIndex = j
            if clusterAssment[i,0] != minIndex: #只要有一個數據點發生變化,就說明分類還沒收斂,還要繼續
                clusterChanged = True

            clusterAssment[i,:] = minIndex,minDist**2 #並將第i個數據點的分配情況存入字典
        print(centroids)
        for cent in range(k):
            ptsInClust = dataSet[nonzero(clusterAssment[:, 0].A == cent)[0]]  # 取第一列等於cent的點
            centroids[cent, :] = mean(ptsInClust, axis=0)  # 算出這些資料的中心點,及當前更新後的質點

        ii = 0
        for cent in clusterAssment:
            #print(cent[0,0])
            if cent[0,0] == 0:
                x,y = dataSet[ii,0],dataSet[ii,1]
                plt.scatter(x, y, c='y')
            elif cent[0, 0] == 1:
                x, y = dataSet[ii, 0], dataSet[ii, 1]
                plt.scatter(x, y, c='r')

            elif cent[0, 0] == 2:
                x, y = dataSet[ii, 0], dataSet[ii, 1]
                plt.scatter(x, y, c='b')

            elif cent[0, 0] == 3:
                x, y = dataSet[ii, 0], dataSet[ii, 1]
                plt.scatter(x, y, c='g')
            ii = ii + 1

        plt.pause(1)
        plt.close()


    return centroids, clusterAssment

datMat = mat(loadDataSet('data.txt'))
myCentroids,clustAssing = kMeans(datMat,4)
print(myCentroids)
#print(clustAssing)

資料集比較簡單,資料很快就收斂。(資料集越複雜,越大,收斂會越慢)

 

可以看到最終得到的四個聚類的中心都在(+-6,+-6)附近,符合題設。

參考:https://www.cnblogs.com/ahu-lichang/p/7161613.html