K-Means演算法實現(Matlab)
K-Means演算法具體內容可以參考我部落格的相關文章,這裡只使用Matlab對其進行實現,其他內容不多贅述
K-Means演算法
1.生成隨機樣本點
首先利用 mvnrnd
函式生成3組滿足高斯分佈的資料,每組資料都是100*2的矩陣,也就相當於生成300個在座標軸上的樣本點
%% 第一組資料 mu1=[0 0];%均值 S1=[0.1 0 ; 0 0.1];%協方差 data1=mvnrnd(mu1,S1,100);%產生高斯分佈資料 %% 第二組資料 mu2=[-1.25 1.25]; S2=[0.1 0 ; 0 0.1]; data2=mvnrnd(mu2,S2,100); %% 第三組資料 mu3=[1.25 1.25]; S3=[0.1 0 ; 0 0.1]; data3=mvnrnd(mu3,S3,100);
mu1
、 mu2
、 mu3
是資料的均值,也就是你將每組點畫在座標軸上,其大致的中心位置座標,例如對於上面的三組資料,中心點就分別為(0,0),(-1.25,1.25),(1.25,1.25),畫在圖上效果如下圖

作圖程式碼如下:
%% 顯示資料 plot(data1(:,1),data1(:,2),'b+'); hold on; plot(data2(:,1),data2(:,2),'b+'); plot(data3(:,1),data3(:,2),'b+');
2.初始化各矩陣
首先我們要將三個100*2的矩陣合併為一個300*2的矩陣 data = [data1;data2;data2]
然後初始化聚類中心,生成N行2列的零矩陣,這裡的N是使用者輸入的想要聚為幾類
還有就是要把data矩陣拷貝一份,儘量在演算法執行過程中執行拷貝矩陣,而不去動data
%%初始化變數 %%初始化工作 data = [data1;data2;data3]; [m,n] = size(data); % m = 300,n = 2 center = zeros(N,n);% 初始化聚類中心,生成N行n列的零矩陣 pattern = data;% 將整個資料拷貝到pattern矩陣中
3.演算法核心
一開始隨機選取300個點中的N個點作為聚類中心(N是使用者輸入的聚類個數)。300個點分別計算到這N箇中心點那一個最短,就將該點分為第幾號。舉個例子:
設有一個點的座標是(0,0),分別有3箇中心點(2,2),(1,1),(3,3),經過計算,(0,0)到(1,1)的距離是最短的,因此將(0,0)這個點劃分為第2類
300個點全部劃分完以後,假設使用者輸入的N是3,劃分成60,90,150,然後計算60個點的中心點座標(只要將60個點的x座標加起來除以60,然後將y座標加起來除以60,就能得到中心點),70個點的中心座標,150個點的中心座標,設這三個中心座標為$(x_a,y_a)$,$(x_b,y_b)$,$(x_c,y_c)$,計算這三個中心點與之前隨機選的三個中心點的距離是否小於一個閾值,如果都小於,則說明分類成功;只要有一個不滿足,首先將這些新的中心座標替換原來的中心座標,然後重新分類
for x = 1 : N center(x,:) = data(randi(300,1),:); % 第一次隨機產生聚類中心 randi返回1*1的(1,300)的數 end while true distence = zeros(1,N);% 產生1行N列的零矩陣 num = zeros(1,N);% 產生1行N列的零矩陣 new_center = zeros(N,n); % 產生N行n列的零矩陣 %% 將所有的點打上標籤1 2 3...N for x = 1 : m for y = 1 : N distence(y) = norm(data(x,:) - center(y,:)); % norm函式計算到每個類的距離 end [~,temp] = min(distence); %求最小的距離 ~是距離值,temp是第幾個 pattern(x,n + 1) = temp; end k = 0; %% 將所有在同一類裡的點座標全部相加,計算新的中心座標 for y = 1 : N for x = 1 : m if pattern(x,n + 1) == y new_center(y,:) = new_center(y,:) + pattern(x,1:n); num(y) = num(y) + 1; end end new_center(y,:) = new_center(y,:) / num(y); if norm(new_center(y,:) - center(y,:)) < 0.1 k = k + 1; end end if k == N break; else center = new_center; end end [m, n] = size(pattern); %[m,n] = [300,3]
4.繪製聚類後的資料點圖
figure; hold on; for i = 1 : m if pattern(i,n) == 1 plot(pattern(i,1),pattern(i,2),'r*'); plot(center(1,1),center(1,2),'ko'); elseif pattern(i,n) == 2 plot(pattern(i,1),pattern(i,2),'g*'); plot(center(2,1),center(2,2),'ko'); elseif pattern(i,n) == 3 plot(pattern(i,1),pattern(i,2),'b*'); plot(center(3,1),center(3,2),'ko'); elseif pattern(i,n) == 4 plot(pattern(i,1),pattern(i,2),'y*'); plot(center(4,1),center(4,2),'ko'); else plot(pattern(i,1),pattern(i,2),'m*'); plot(center(5,1),center(5,2),'ko'); end end
完整程式碼
clear; clc; N = input('請設定聚類數目:');%設定聚類數目 %% 第一組資料 mu1=[0 0];%均值 S1=[0.1 0 ; 0 0.1];%協方差 data1=mvnrnd(mu1,S1,100);%產生高斯分佈資料 %% 第二組資料 mu2=[-1.25 1.25]; S2=[0.1 0 ; 0 0.1]; data2=mvnrnd(mu2,S2,100); %% 第三組資料 mu3=[1.25 1.25]; S3=[0.1 0 ; 0 0.1]; data3=mvnrnd(mu3,S3,100); %% 顯示資料 plot(data1(:,1),data1(:,2),'b+'); hold on; plot(data2(:,1),data2(:,2),'b+'); plot(data3(:,1),data3(:,2),'b+'); %%初始化工作 data = [data1;data2;data3]; [m,n] = size(data); % m = 300,n = 2 center = zeros(N,n);% 初始化聚類中心,生成N行n列的零矩陣 pattern = data;% 將整個資料拷貝到pattern矩陣中 %% 演算法 for x = 1 : N center(x,:) = data(randi(300,1),:); % 第一次隨機產生聚類中心 randi返回1*1的(1,300)的數 end while true distence = zeros(1,N);% 產生1行N列的零矩陣 num = zeros(1,N);% 產生1行N列的零矩陣 new_center = zeros(N,n); % 產生N行n列的零矩陣 %% 將所有的點打上標籤1 2 3...N for x = 1 : m for y = 1 : N distence(y) = norm(data(x,:) - center(y,:)); % norm函式計算到每個類的距離 end [~,temp] = min(distence); %求最小的距離 ~是距離值,temp是第幾個 pattern(x,n + 1) = temp; end k = 0; %% 將所有在同一類裡的點座標全部相加,計算新的中心座標 for y = 1 : N for x = 1 : m if pattern(x,n + 1) == y new_center(y,:) = new_center(y,:) + pattern(x,1:n); num(y) = num(y) + 1; end end new_center(y,:) = new_center(y,:) / num(y); if norm(new_center(y,:) - center(y,:)) < 0.1 k = k + 1; end end if k == N break; else center = new_center; end end [m, n] = size(pattern); %[m,n] = [300,3] %% 最後顯示聚類後的資料 figure; hold on; for i = 1 : m if pattern(i,n) == 1 plot(pattern(i,1),pattern(i,2),'r*'); plot(center(1,1),center(1,2),'ko'); elseif pattern(i,n) == 2 plot(pattern(i,1),pattern(i,2),'g*'); plot(center(2,1),center(2,2),'ko'); elseif pattern(i,n) == 3 plot(pattern(i,1),pattern(i,2),'b*'); plot(center(3,1),center(3,2),'ko'); elseif pattern(i,n) == 4 plot(pattern(i,1),pattern(i,2),'y*'); plot(center(4,1),center(4,2),'ko'); else plot(pattern(i,1),pattern(i,2),'m*'); plot(center(5,1),center(5,2),'ko'); end end
執行的GIF圖如下:

存在的問題以及改進方法
這只是一個比較簡單的K-Means聚類程式碼,其中可能存在兩個問題:
- 死迴圈
- 聚類不準確
第一個問題產生的原因很簡單,如果用筆算過K-Means就會知道,對於一個數據集,可能的聚類方式不止一種,並且存在確實無法達到所有的聚類中心差都小於閾值的情況。解決辦法是加一個變數 times
用於記錄執行了多少次while迴圈,當times達到一個很大的值而依舊沒有停止程式,可以判斷出現了死迴圈,乾脆直接輸出結果,不再計算。
第二個問題產生的效果圖如下

對於右邊的樣本集,我們用肉眼觀察很明顯聚類應該如紅框所示,但是使用K-Means聚類後得到的結果與預期差異較大,究其原因有很多種,具體解決辦法是將閾值減小,以達到更加精確的聚類