1. 程式人生 > >K-means原理及Matlab實現

K-means原理及Matlab實現

前言

作為勵志在機器學習界闖出一片天地的小女子,在此整理學習到的機器學習方法,並使用Matlab及Python實現。希望可以和大家互相交流和探討。
聯絡方式:[email protected]

K-means原理

k-means應該是入門機器學習最早接觸的演算法之一了,它使用簡單富有美感的演算法深刻地表達了教機器學習的思想,其中蘊含的EM思想會在後續的博文中詳細講解。

從它的名字上解析一下,k代表了你要將資料分為幾類(也就是後文提到的seed個數,別急,後面你會深刻理解),而means即為平均值,這也是此演算法的核心。

巨集觀上來看它屬於無監督學習,(無監督學習指:資料僅給出了特徵值,未給出資料的標籤,監督學習則同時給出了特徵值和標籤)

它具體的思想是什麼呢,容我細細道來。

假設我們要將以下情況的點分開:

這裡寫圖片描述

在這個例子中,我們用肉眼可以觀察到這個資料大致可以分成兩類。而需要注意的是在實際情況的應用上,可以直接觀察到類別個數基本是不可能的,主要原因有以下兩點:

  1. 實際資料特徵向量的維數很高,實現視覺化是十分困難的。
  2. 實際資料通常耦合程度高,沒有清晰的分界線。

所以在這裡需要提到一個題外話,在現實生活中往往是根據實際需要確定分類類別。舉個栗子:你是一家制衣廠的老闆,你有一大堆使用者身高體重肩寬腰圍等等的資料,你希望可以對使用者群體分個類,來確定s,m,l碼的衣服分別應該適合多大身材維度的使用者。這裡就可以使用kmeans。在這裡,你就按照實際情況直接將k設定成了3。

好了,聊了一些閒話,現在重回主題,我們希望將上圖資料進行分類,並且確定了k=2,即分成兩類。接下來,我們在所有資料點中隨機選取兩個種子(seed),播種下這兩顆種子,一切交給機器來學習吧!

這裡寫圖片描述

所有點將會和兩個seed進行比較,和哪個更加相似就加入哪個seed的陣營。

這裡需要引入一個相似的概念,在數學上,表徵相似程度的引數有許多:距離,相關係數等等。本文程式碼採用簡單的歐式距離,有心的小夥伴可以嘗試不同的方法進行嘗試。

這裡寫圖片描述

上圖中所有的點都已經確定了自己的紅藍陣營。此時重新計算seed值,即所有紅色點的特徵值求平均作為新的紅seed,所有藍色點的特徵值求平均作為新的藍seed。

這裡寫圖片描述

不斷重複以上的過程,即可完成最終分類。

這裡寫圖片描述

這裡寫圖片描述

Matlab實現

這是一個實現kmeans的函式:

function [seed_new all_data] = kmeans_f(data,k)

% % 輸入data是所有原始資料構成的陣列,k是選定的分類個數
% % 輸出seed_new是最終seed,all_data是最終data的特徵和label 值(第一列為label)

[m,n]=size(data);
length=m;
feature_number=n;

r=zeros(k);

% % 根據k初始化seed的index
r=randperm(length,k);

% %儲存所有的seed
seed_new=zeros(k,feature_number);
for j=1:k
    seed_new(j,:)=data(r(j),:);
end

all_data=zeros(length,feature_number+1);
% % % while 1是相當於do while迴圈
while 1
    seed_old=seed_new;
    for j=1:length
        choosen_point=data(j,:);
        dist_list=zeros(k,1);
        for i=1:k
            seed=seed_old(i,:);
            dist=norm(choosen_point-seed);
            dist_list(i,:)=dist;
        end

% %     返回最小距離的cluster index
        cluster_index=find(dist_list==min(dist_list));
% %         有時返回好幾個值
        cluster_index_point=cluster_index(1);

% %  將j點寫入第cluster-index的類中,all_data是feature_number+1維陣列,其中增加了第一列為其屬於cluster的編號。

     all_data(j,1)=cluster_index_point;
     all_data(j,2:end)=choosen_point;

    end

% % 計算得到新的seed矩陣,注意一定要按照cluster-index的順序排列
    seed_new_sum=zeros(k,feature_number);
    seed_number=zeros(k,1);
    for i=1:length
        for j=1:k
          if all_data(i,1)==j
              seed_new_sum(j,:)=seed_new_sum(j,:)+all_data(i,2:end);
% %               統計每個cluster中點的個數
              seed_number(j,:)=seed_number(j,:)+1;

          end
        end
    end
% %    計算出新的seed
   seed_new=zeros(k,feature_number);
   for i=1:feature_number
       seed_new(:,i)=seed_new_sum(:,i)./seed_number;
   end

% % 跳出迴圈的條件是seed的改變數非常小。 
   judge=norm(seed_new-seed_old)
    if judge<=0.01
        break;
    end
end
end

對這個函式進行測試:

clear all;
close all;
clc;

%第一類資料
mu1=[0 0 0];  %均值
S1=[0.3 0 0;0 0.35 0;0 0 0.3];  %協方差
data1=mvnrnd(mu1,S1,100);   %產生高斯分佈資料

% %第二類資料
mu2=[1.25 1.25 1.25];
S2=[0.3 0 0;0 0.35 0;0 0 0.3];
data2=mvnrnd(mu2,S2,100);

% %第三個類資料
mu3=[-1.25 1.25 -1.25];
S3=[0.3 0 0;0 0.35 0;0 0 0.3];
data3=mvnrnd(mu3,S3,100);

% %顯示資料
plot3(data1(:,1),data1(:,2),data1(:,3),'+');
hold on;
plot3(data2(:,1),data2(:,2),data2(:,3),'+');
plot3(data3(:,1),data3(:,2),data3(:,3),'+');
grid on;

測試得到的結果如下:

原始資料:

原始資料(未分類)

分類後的資料:

這裡寫圖片描述

結語

感謝大家看到這裡,歡迎隨時溝通交流~