1. 程式人生 > >【機器學習】機器學習(十二、十三):K-means演算法、高斯混合模型

【機器學習】機器學習(十二、十三):K-means演算法、高斯混合模型

簡介:

        本節介紹STANFORD機器學習公開課中的第12、13集視訊中的演算法:K-means演算法、高斯混合模型(GMM)。(9、10、11集不進行介紹,略過了哈)

一、K-means演算法

        屬於無監督學習的聚類演算法,給定一組未標定的資料(輸入樣本),對其進行分類,假設可分為k個類。由於演算法比較直觀,故直接給出步驟和MATLAB程式碼。(k-means演算法在數學推導上是有意義的)


MATLAB程式碼:

%%
%k均值聚類
clear all;
close all;
%%
n=2;
m=200;
v0=randn(m/2,2)-1;
v1=randn(m/2,2)+1;
figure;
subplot(221);
hold on;
plot(v0(:,1),v0(:,2),'r.');
plot(v1(:,1),v1(:,2),'b.');
%axis([-5 5 -5 5]);
title('已分類資料');
hold off;

data=[v0;v1];
data=sortrows(data,1);
subplot(222);
plot(data(:,1),data(:,2),'g.');
title('未分類資料');
%axis([-5 5 -5 5]);
%%
[a b]=size(data);
m1=data(20,:);%隨機取重心點
m2=data(120,:);%隨機取重心點
k1=zeros(1,2);
k2=zeros(1,2);
n1=0;
n2=0;
subplot(223);hold on;
%axis([-5 5 -5 5]);
for t=1:10
    for i=1:a
        d1=pdist2(m1,data(i,:));
        d2=pdist2(m2,data(i,:));
        if (d1<d2)
            k1=k1+data(i,:);
            n1=n1+1;
            plot(data(i,1),data(i,2),'r.');
        else
            k2=k2+data(i,:);
            n2=n2+1;
            plot(data(i,1),data(i,2),'b.');
        end
    end
    m1=k1/n1;
    m2=k2/n2;
%     plot(m1(1,1),m1(1,2),'g.');
%     plot(m2(1,1),m2(1,2),'g.');
    k1=zeros(1,2);
    k2=zeros(1,2);
    n1=0;
    n2=0;
end
plot(m1(1,1),m1(1,2),'k*');
plot(m2(1,1),m2(1,2),'k*');
title('k-means聚類');
hold off;

輸出結果(未分類資料是由已分類資料去掉標籤,黑色※號表示聚類中心):

二、高斯混合模型(GMM)

           回想之前之前的高斯判別分析法(GDA),是通過計算樣本的後驗概率來進行判別,而後驗概率是通過假設多元高斯模型來計算得來的。高斯模型的引數:均值、協方差,是由已標定(分類)的樣本得來,所以可以看做是一種監督學習方法。

        在GMM模型(屬於無監督學習),給定未分類的m個樣本(n維特徵),假設可分為k個類,要求用GMM演算法對其進行分類。如果我們知道每個類的高斯引數,則可以向GDA演算法那樣計算出後驗概率進行判別。但遺憾的是,楊輸入的樣本未被標定,也就是說我們得不到高斯引數:均值、協方差。這就引出EM(Expectation Maximization Algorithm:期望最大化

)演算法。

        EM演算法的思想有點類似於k-means,就是通過迭代來得出最好的引數,有了這些引數就可以像GDA那樣做分類了。GMM及EM具體步驟如下:


MATLAB程式碼如下:

%%
%GMM演算法(高斯混合模型)soft assignment(軟劃分)
clear all;
close all;
%%
k=2;%聚類數
n=2;%維數
m=200;
% v0=randn(m/2,2)-1;
% v1=randn(m/2,2)+1;
v0=mvnrnd([1 1],[1 0;0 1],m/2);%生成正樣本1
v1=mvnrnd([4 4],[1 0;0 1],m/2);%生成負樣本0
figure;subplot(221);
hold on;
plot(v0(:,1),v0(:,2),'r.');
plot(v1(:,1),v1(:,2),'b.');
title('已分類資料');
hold off;
%%
data=[v0;v1];
data=sortrows(data,1);
subplot(222);
plot(data(:,1),data(:,2),'g.');
title('未分類資料');
%%
mu1=mean(data(1:50,:));
mu2=mean(data(100:180,:));
sigma1=cov(data(1:50,:));
sigma2=cov(data(100:180,:));
p=zeros(m,k);%概率
thresh=0.05;%迭代終止條件
iter=0;%記錄迭代次數
while(1)
    iter=iter+1;
    A1=1/(((2*pi)^(n/2))*((det(sigma1))^(1/2)));
    A2=1/(((2*pi)^(n/2))*((det(sigma2))^(1/2)));
    for i=1:m
        p(i,1)=A1*exp((-1/2)*(data(i,:)-mu1)*sigma1*(data(i,:)-mu1)');
        p(i,2)=A2*exp((-1/2)*(data(i,:)-mu2)*sigma2*(data(i,:)-mu2)');
        pp=sum(p(i,:));
        p(i,1)=p(i,1)/pp;%歸一化,樣本屬於某類的概率的總和為1
        p(i,2)=p(i,2)/pp;
    end
    sum1=zeros(n,n);
    sum2=zeros(n,n);
    for i=1:m
        sum1=sum1+p(i,1)*(data(i,:)-mu1)'*(data(i,:)-mu1);
        sum2=sum2+p(i,2)*(data(i,:)-mu2)'*(data(i,:)-mu2);
    end
    sigma1=sum1/sum(p(:,1));
    sigma2=sum2/sum(p(:,2));
    mu1_pre=mu1;
    mu2_pre=mu2;
    mu1=(p(:,1)'*data)/sum(p(:,1));
    mu2=(p(:,2)'*data)/sum(p(:,2));
    if ((pdist2(mu1_pre,mu1)<=thresh) || (pdist2(mu2_pre,mu2)<=thresh))
        break;
    end
end
%%
subplot(223);
hold on;
A1=1/(((2*pi)^(n/2))*((det(sigma1))^(1/2)));
A2=1/(((2*pi)^(n/2))*((det(sigma2))^(1/2)));
for i=1:m
    p(i,1)=A1*exp((-1/2)*(data(i,:)-mu1)*sigma1*(data(i,:)-mu1)');
    p(i,2)=A2*exp((-1/2)*(data(i,:)-mu2)*sigma2*(data(i,:)-mu2)');
    if p(i,1)>=p(i,2)
        plot(data(i,1),data(i,2),'r.');
    else
        plot(data(i,1),data(i,2),'b.');
    end
end
title('GMM分類');
hold off;
%完
輸出結果: