【機器學習】機器學習(十二、十三):K-means演算法、高斯混合模型
阿新 • • 發佈:2019-01-09
簡介:
本節介紹STANFORD機器學習公開課中的第12、13集視訊中的演算法:K-means演算法、高斯混合模型(GMM)。(9、10、11集不進行介紹,略過了哈)
一、K-means演算法
屬於無監督學習的聚類演算法,給定一組未標定的資料(輸入樣本),對其進行分類,假設可分為k個類。由於演算法比較直觀,故直接給出步驟和MATLAB程式碼。(k-means演算法在數學推導上是有意義的)
MATLAB程式碼:
%% %k均值聚類 clear all; close all; %% n=2; m=200; v0=randn(m/2,2)-1; v1=randn(m/2,2)+1; figure; subplot(221); hold on; plot(v0(:,1),v0(:,2),'r.'); plot(v1(:,1),v1(:,2),'b.'); %axis([-5 5 -5 5]); title('已分類資料'); hold off; data=[v0;v1]; data=sortrows(data,1); subplot(222); plot(data(:,1),data(:,2),'g.'); title('未分類資料'); %axis([-5 5 -5 5]); %% [a b]=size(data); m1=data(20,:);%隨機取重心點 m2=data(120,:);%隨機取重心點 k1=zeros(1,2); k2=zeros(1,2); n1=0; n2=0; subplot(223);hold on; %axis([-5 5 -5 5]); for t=1:10 for i=1:a d1=pdist2(m1,data(i,:)); d2=pdist2(m2,data(i,:)); if (d1<d2) k1=k1+data(i,:); n1=n1+1; plot(data(i,1),data(i,2),'r.'); else k2=k2+data(i,:); n2=n2+1; plot(data(i,1),data(i,2),'b.'); end end m1=k1/n1; m2=k2/n2; % plot(m1(1,1),m1(1,2),'g.'); % plot(m2(1,1),m2(1,2),'g.'); k1=zeros(1,2); k2=zeros(1,2); n1=0; n2=0; end plot(m1(1,1),m1(1,2),'k*'); plot(m2(1,1),m2(1,2),'k*'); title('k-means聚類'); hold off;
輸出結果(未分類資料是由已分類資料去掉標籤,黑色※號表示聚類中心):
二、高斯混合模型(GMM)
回想之前之前的高斯判別分析法(GDA),是通過計算樣本的後驗概率來進行判別,而後驗概率是通過假設多元高斯模型來計算得來的。高斯模型的引數:均值、協方差,是由已標定(分類)的樣本得來,所以可以看做是一種監督學習方法。
在GMM模型(屬於無監督學習),給定未分類的m個樣本(n維特徵),假設可分為k個類,要求用GMM演算法對其進行分類。如果我們知道每個類的高斯引數,則可以向GDA演算法那樣計算出後驗概率進行判別。但遺憾的是,楊輸入的樣本未被標定,也就是說我們得不到高斯引數:均值、協方差。這就引出EM(Expectation
Maximization Algorithm:期望最大化 )演算法。
EM演算法的思想有點類似於k-means,就是通過迭代來得出最好的引數,有了這些引數就可以像GDA那樣做分類了。GMM及EM具體步驟如下:
MATLAB程式碼如下:
輸出結果:%% %GMM演算法(高斯混合模型)soft assignment(軟劃分) clear all; close all; %% k=2;%聚類數 n=2;%維數 m=200; % v0=randn(m/2,2)-1; % v1=randn(m/2,2)+1; v0=mvnrnd([1 1],[1 0;0 1],m/2);%生成正樣本1 v1=mvnrnd([4 4],[1 0;0 1],m/2);%生成負樣本0 figure;subplot(221); hold on; plot(v0(:,1),v0(:,2),'r.'); plot(v1(:,1),v1(:,2),'b.'); title('已分類資料'); hold off; %% data=[v0;v1]; data=sortrows(data,1); subplot(222); plot(data(:,1),data(:,2),'g.'); title('未分類資料'); %% mu1=mean(data(1:50,:)); mu2=mean(data(100:180,:)); sigma1=cov(data(1:50,:)); sigma2=cov(data(100:180,:)); p=zeros(m,k);%概率 thresh=0.05;%迭代終止條件 iter=0;%記錄迭代次數 while(1) iter=iter+1; A1=1/(((2*pi)^(n/2))*((det(sigma1))^(1/2))); A2=1/(((2*pi)^(n/2))*((det(sigma2))^(1/2))); for i=1:m p(i,1)=A1*exp((-1/2)*(data(i,:)-mu1)*sigma1*(data(i,:)-mu1)'); p(i,2)=A2*exp((-1/2)*(data(i,:)-mu2)*sigma2*(data(i,:)-mu2)'); pp=sum(p(i,:)); p(i,1)=p(i,1)/pp;%歸一化,樣本屬於某類的概率的總和為1 p(i,2)=p(i,2)/pp; end sum1=zeros(n,n); sum2=zeros(n,n); for i=1:m sum1=sum1+p(i,1)*(data(i,:)-mu1)'*(data(i,:)-mu1); sum2=sum2+p(i,2)*(data(i,:)-mu2)'*(data(i,:)-mu2); end sigma1=sum1/sum(p(:,1)); sigma2=sum2/sum(p(:,2)); mu1_pre=mu1; mu2_pre=mu2; mu1=(p(:,1)'*data)/sum(p(:,1)); mu2=(p(:,2)'*data)/sum(p(:,2)); if ((pdist2(mu1_pre,mu1)<=thresh) || (pdist2(mu2_pre,mu2)<=thresh)) break; end end %% subplot(223); hold on; A1=1/(((2*pi)^(n/2))*((det(sigma1))^(1/2))); A2=1/(((2*pi)^(n/2))*((det(sigma2))^(1/2))); for i=1:m p(i,1)=A1*exp((-1/2)*(data(i,:)-mu1)*sigma1*(data(i,:)-mu1)'); p(i,2)=A2*exp((-1/2)*(data(i,:)-mu2)*sigma2*(data(i,:)-mu2)'); if p(i,1)>=p(i,2) plot(data(i,1),data(i,2),'r.'); else plot(data(i,1),data(i,2),'b.'); end end title('GMM分類'); hold off; %完