機器學習--K-means演算法
阿新 • • 發佈:2018-11-08
概述
聚類(K-mean)是一種典型的無監督學習。
採用距離作為相似性的評價指標,即認為兩個物件的距離越近,其相似度就越大。
該演算法認為類簇是由距離靠近的物件組成的,因此把得到緊湊且獨立的簇作為最終目標。
核心思想
通過迭代尋找k個類簇的一種劃分方案,使得用這k個類簇的均值來代表相應各類樣本時所得的總體誤差最小。
k個聚類具有以下特點:各聚類本身儘可能的緊湊,而各聚類之間儘可能的分開。
k-means演算法的基礎是最小誤差平方和準則,
其代價函式是:
式中,μc(i)表示第i個聚類的均值。
各類簇內的樣本越相似,其與該類均值間的誤差平方越小,對所有類所得到的誤差平方求和,即可驗證分為k類時,各聚類是否是最優的。
上式的代價函式無法用解析的方法最小化,只能有迭代的方法。
實踐
第一步,為了測試,使用指令碼生成1000個數據的資料集。
import numpy as np #x = np.random.uniform(-6,6,2) #print(str(x[0])+'\t'+str(x[1])) with open('data.txt', 'w') as f: # 以寫的方式開啟檔案 for i in range(250): x1 = 2*np.random.randn(2) x2 = 2 * np.random.randn(2) x3 = 2 * np.random.randn(2) x4 = 2 * np.random.randn(2) strr1 = str(x1[0]+6)+'\t'+str(x1[1]+6)+'\n' strr2 = str(x2[0] + 6) + '\t' + str(x2[1] - 6) + '\n' strr3 = str(x3[0] - 6) + '\t' + str(x3[1] + 6) + '\n' strr4 = str(x4[0] - 6) + '\t' + str(x4[1] - 6) + '\n' strr = strr1+strr2+strr3+strr4 f.write(strr)
檢視data.txt裡面的資料:
第二步K-means演算法:(解釋都在註釋中)
from numpy import * import numpy as np import matplotlib.pyplot as plt import time from threading import Thread #載入資料 plt.ion() #開啟interactive mode def loadDataSet(fileName):#解析檔案,按tab分割字元,得到一個浮點數字型別的矩陣 dataMat = []#檔案的最後一個欄位是類別標籤 fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split('\t') fltLine = list(map(float,curLine))#將每個元素轉成float型別 dataMat.append(fltLine) dataMat = np.array(dataMat) return dataMat def distEclud(vecA,vecB): return sqrt(sum(power(vecA - vecB,2)))#求兩個向量之間的距離 #構建聚簇中心,取k個(此例中為4)隨機質心 def randCent(dataSet,k): n = shape(dataSet)[1] centroids = mat(zeros((k,n)))#每個質心有n個座標值,總共k個質心 for j in range(n): minJ = min(dataSet[:,j]) maxJ = max(dataSet[:,j]) rangeJ = float(maxJ-minJ) centroids[:,j] = minJ + rangeJ * random.rand(k,1) return centroids #k-means聚類演算法 def kMeans(dataSet,k,distMeans = distEclud,createCent = randCent): m = shape(dataSet)[0]#獲取總資料量 clusterAssment = mat(zeros((m,2)))#用於存放該樣本屬於哪類及質心距離 #clusterAssment第一列存放該資料所屬的中心點,第二列是該資料到中心點的距離 centroids = createCent(dataSet,k)#建立k箇中心點 clusterChanged = True#用來判斷聚類是否收斂 while clusterChanged: clusterChanged = False for i in range(m):#把每個資料點劃分到離它最近的中心點 minDist = inf;minIndex = -1;#inf為無窮,minIndex為質心的代號 for j in range(k):#分別計算各個點離k個質心的距離 distJI = distMeans(centroids[j,:],dataSet[i,:]) if distJI < minDist:#找到離這個點最近的質心 minDist = distJI;minIndex = j if clusterAssment[i,0] != minIndex: #只要有一個數據點發生變化,就說明分類還沒收斂,還要繼續 clusterChanged = True clusterAssment[i,:] = minIndex,minDist**2 #並將第i個數據點的分配情況存入字典 print(centroids) for cent in range(k): ptsInClust = dataSet[nonzero(clusterAssment[:, 0].A == cent)[0]] # 取第一列等於cent的點 centroids[cent, :] = mean(ptsInClust, axis=0) # 算出這些資料的中心點,及當前更新後的質點 ii = 0 for cent in clusterAssment: #print(cent[0,0]) if cent[0,0] == 0: x,y = dataSet[ii,0],dataSet[ii,1] plt.scatter(x, y, c='y') elif cent[0, 0] == 1: x, y = dataSet[ii, 0], dataSet[ii, 1] plt.scatter(x, y, c='r') elif cent[0, 0] == 2: x, y = dataSet[ii, 0], dataSet[ii, 1] plt.scatter(x, y, c='b') elif cent[0, 0] == 3: x, y = dataSet[ii, 0], dataSet[ii, 1] plt.scatter(x, y, c='g') ii = ii + 1 plt.pause(1) plt.close() return centroids, clusterAssment datMat = mat(loadDataSet('data.txt')) myCentroids,clustAssing = kMeans(datMat,4) print(myCentroids) #print(clustAssing)
資料集比較簡單,資料很快就收斂。(資料集越複雜,越大,收斂會越慢)
可以看到最終得到的四個聚類的中心都在(+-6,+-6)附近,符合題設。