1. 程式人生 > >機器學習筆記(十)EM演算法及實踐(以混合高斯模型(GMM)為例來次完整的EM)

機器學習筆記(十)EM演算法及實踐(以混合高斯模型(GMM)為例來次完整的EM)

今天要來討論的是EM演算法。第一眼看到EM我就想到了我大楓哥,EM Master,千里馬,RUA!!!不知道看這個部落格的人有沒有懂這個梗的。好的,言歸正傳,今天要講的EM演算法,全稱是Expectation maximization,期望最大化。怎麼個意思呢,就是給你一堆觀測樣本,讓你給出這個模型的引數估計。我靠,這套路我們前面討論各種迴歸的時候不是已經用爛了嗎?求期望,求對數期望,求導為0,得到引數估計值,這套路我懂啊,MLE!但問題在於,如果這個問題存在中間的隱變數呢?會不會把我們的套路給帶崩呢,我們通過兩個例子來認識一下這兩種情況。

====================================================================

不存在中間變數的EM。

假設有一天人類消除了性別的差別,所有的人都是同一個性別。這個時候,我給了你一群人的身高讓你給我做一個估計人身高的模型。

怎麼辦呢?感覺上身高應該是服從高斯分佈吧,所以假設人的身高分佈服從高斯分佈N(Mu,Sigma^2),現在我又有了N個人的身高的資料,我就可以照著上面的套路進行了。先求對數似然函式

接著對兩個引數求偏導為0

這樣就得到了我們的引數估計

喜聞樂見的結果,又好求又符合我們的直覺,那我們再來看看另一種情況。

====================================================================

存在中間變數的EM。

正如你所知,身高和人種的關係挺大的,而人類又有那麼多種族,所以,再給你一堆人的資料,要做一個估計人身高的模型,那我們應該怎麼做呢?

首先,現在分為不同種族若干類了,這些類別的概率肯定有個分佈,其次,各種族當中身高是服從不同的分佈的,那麼這樣身高的估計就變成了

Alpha代表了該樣本屬於某一人種的比例,其實就是隱藏的中間變數。Muk和sigmak^2為各類高斯分佈的引數。按照我們上面的套路就是求對數似然概率再求導得到引數的估計,那麼先來看看似然函式

這下尷尬的情況出現了,對數裡面帶加號,這下求導就變得複雜異常了,而且沒法求解,事實上,這種式子確實沒有解析解。不過憋灰心啊,假設我們隨便猜一個alpha的分佈為Q,那麼對數似然函式可以寫成

由於Q是alpha的一個分佈,所以似然函式可以看成是一個log(E(x)),log是個凹函式啊,割線始終在函式影象下方,Jensen不等式反向應用一下,有log(E(x))>=E(log(x)),所以上面的對數似然有

冷靜分析一下現在的情況,我們現在得到了一個對數似然函式的下界函式,我們採用曲線救國的戰略,我們求解它的區域性最大值,那麼更新後的引數帶入這個下界函式一定比之前的引數值大,而它本身又是對數似然函式的下界函式,所以引數更新後,我們的對數似然函式一定是變大了!所以,就利用這種方法進行迭代,最後就能得到比較好的引數估計。還有點暈嗎,沒事,我從百度扒個圖給你來個形象的解釋

紅色那條線就是我們的對數似然函式,藍色那條是我們在當前引數下找到的對數似然的下界函式,可以看到,我們找到它的區域性極值那麼引數更新成thetanew,此時對數似然函式的值也得到了上升,這樣重複進行下去,是不是就可以收斂到對數似然函式的一個區域性極值了嘛。對的,區域性極值,並不能保證是全域性最優,但它就是個估計嘛,你還要她怎樣?!

到了這裡,我們好像跳著先把第二步引數更新的工作做完了,那麼還有一個事情是我們要注意的,Q呢,Q是啥,沒Q你算啥極值,更新啥引數。我們已經知道Q是alpha的一個分佈,然後我們肯定是希望這個下界函式儘量貼近原來的對數似然函式,這樣我們才能更快地更新引數,那下界函式啥時最大呢,等號成立唄,等號成立說明你求期望的物件是個常數呀,所以log和Q誰在前後都無所謂,那麼就有了

直觀地可以理解成第i個樣本來自第k個類別的可能性。好了,現在Q也確定了,我們根據上面所說的方法更新引數,再更新Q,再更新引數,迭代進行下去就可以了。

如果你能堅持看到這裡,少俠我只能說你大功已成。因為其實我們已經把EM演算法整個推導完了,也許你還是有點雲裡霧裡,那我們再來仔細梳理一下這個流程

1 拿到所有的觀測樣本,根據先驗或者喜好先給一個引數估計。

2 根據這個引數估計和樣本計算類別分佈Q,得到最貼近對數似然函式的下界函式。

3 對下界函式求極值,更新引數分佈。

4 迭代計算,直至收斂。

說起來啊,EM演算法據說是機器學習進階的一個演算法,但至少目前來看,它的思路還是很容易理解的嘛,整個過程中唯一一個可能初學者覺得有點繞的地方就是應用Jensen不等式的那一步,那我再囉嗦兩句。所謂Jensen不等式,你可以這麼理解,對於一個凸函式f而言,它的割線始終在函式影象上方你承認吧,我在上面任取兩點x1,x2,引數theta介於0到1之間,那麼theta*x1+(1-theta)*x2就是介於x1和x2之間的一點吧,在這點上過x1x2割線的值大於函式值吧,是不是就有了theta*f(x1)+(1-theta)*f(x2)>f(theta*x1+(1-theta)*x2),根據這個結論再推廣開來,就有E(f(x))>f(E(x)),在對數似然函式中,由於log是個凹函式,所以把它反過來用,老鐵沒毛病吧?!這一點想通了我覺得整個EM演算法的流程還是蠻好懂的。

下面呢,我們還回到這個身高模型的預測,假設給了m個樣本,有k個種族,每個種族的身高都是服從高斯分佈的,那麼這就變成了EM演算法中最具代表性的一個例子,高斯混合模型GMM。

====================================================================高斯混合模型(GMM)

剛才已經講了EM演算法的套路了,假設我們現在處於某一步迭代中,那麼我們該幹嘛呢?

E-step 求最佳的類別分佈

可以將其理解為第i個樣本屬於第J類的概率。

M-step 更新引數

求得了Q之後,我們就得到了最貼近原對數似然函式的下界函式,那我們對它求極值就可以得到更新後的引數,先看一下這個下界函式

Log函式裡面全是乘積項這是我們最喜歡的形式,這樣求導的時候但凡不相關的我們直接扔掉就行,待求引數mu,sigma^2,psi,依次求導為0就成。

對於psi的求解可能複雜一些,首先我們把下界函式中與psi不相關的項全部去掉,然後psi作為各類別的比例有一個天然的約束條件就是所有的psi之和為1,所以目標函式變成

這種帶約束的優化前面在SVM的時候不知道用了多少回,拉格朗日乘子法

接下來對psi求導

兩邊同時再對j從1到k連加,psi那一項就沒了,右式就變成樣本數目m,這樣就求得了beta,回代我們就可以求得psi引數的更新

至此,所有的引數更新工作就已完成,下面重複進行迭代就行了。

我們先把GMM的演算法梳理一下

1 給引數取初始值,開始迭代。

2 求每個樣本對每個類別的概率,科學的叫法叫求響應度

3 更新模型引數

4 重複23兩步直至收斂。

我們再來看看這些引數的意義,其實未嘗不符合我們的直覺認識。W(i,j)可以看做第i個樣本屬於第j類的概率,那麼所有樣本中屬於第j類的個數就是w(i,j)之和,每個樣本xi對應第j類的值就是W(i,j) xi,這樣算的平均數就是第j類對應的mu,繼續按照這個思路算的方差就是第j類的sigma^2,第j類的概率就是第j類的個數除以總樣本數。所以,GMM模型雖然推導起來有點嚇人,但仔細想想它最後的結果也是符合我們的直覺認識的,每個樣本都是一部分屬於某一類,所有樣本中的某一類的部分構成了這一類的分佈,perfect!!!

====================================================================

這樣的話,理論部分我們就講完了,接下來又是調包俠的時刻了,上次寫完後我想到鳶尾花資料無監督演算法也能做啊,不給標籤我們強行給它分類看看效果如何。所以這裡我們K-Means和GMM演算法分別對鳶尾花進行處理,看看它們的聚類效果如何。

程式碼如下

  1. import numpy as np

  2. from sklearn import datasets

  3. from sklearn.cluster import KMeans

  4. from sklearn.mixture import GaussianMixture

  5. #讀取資料

  6. iris=datasets.load_iris()

  7. x=iris.data[:,:2]

  8. y=iris.target

  9. mu = np.array([np.mean(x[y == i], axis=0) for i in range(3)])

  10. print '實際均值 = \n', mu

  11. #K-Means

  12. kmeans=KMeans(n_clusters=3,init='k-means++',random_state=0)

  13. y_hat1=kmeans.fit_predict(x)

  14. mu1=np.array([np.mean(x[y_hat1 == i], axis=0) for i in range(3)])

  15. print 'K-Means均值 = \n', mu1

  16. print '分類正確率為',np.mean(y_hat1==y)

  17. gmm=GaussianMixture(n_components=3,covariance_type='full', random_state=0)

  18. gmm.fit(x)

  19. print 'GMM均值 = \n', gmm.means_

  20. y_hat2=gmm.predict(x)

  21. print '分類正確率為',np.mean(y_hat2==y)

輸出結果為

實際均值 =

[[5.006  3.418]

 [5.936  2.77 ]

 [6.588  2.974]]

K-Means均值 =

[[5.77358491  2.69245283]

 [ 5.006      3.418     ]

 [ 6.81276596 3.07446809]]

分類正確率為 0.233333333333

GMM均值 =

[[5.01494511  3.44040237]

 [ 6.69225795 3.03018616]

 [ 5.90652226 2.74740414]]

分類正確率為 0.533333333333怒摔鍵盤啊,什麼破正確率呀!!!憋急啊,我看事情並不簡單。機智的我們觀察一下均值矩陣。K-Means給出的第一行似乎和實際的第二行很接近,第二行和實際的第一行很接近。同樣,GMM給出的均值矩陣也有同樣的問題,第二行和第三行似乎對調了。這不是演算法的鍋啊,它只管給你聚類,哪裡還能保證標籤和你一樣啊,三個類別六種標籤方式人家演算法也只能隨機一種好嗎,所以現在我們把預測的結果的標籤改一下看看實際的正確率如何。

  1. import numpy as np

  2. from sklearn import datasets

  3. from sklearn.cluster import KMeans

  4. from sklearn.mixture import GaussianMixture

  5. #讀取資料

  6. iris=datasets.load_iris()

  7. x=iris.data[:,:2]

  8. y=iris.target

  9. mu = np.array([np.mean(x[y == i], axis=0) for i in range(3)])

  10. print '實際均值 = \n', mu

  11. #K-Means

  12. kmeans=KMeans(n_clusters=3,init='k-means++',random_state=0)

  13. y_hat1=kmeans.fit_predict(x)

  14. y_hat1[y_hat1==0]=3

  15. y_hat1[y_hat1==1]=0

  16. y_hat1[y_hat1==3]=1

  17. mu1=np.array([np.mean(x[y_hat1 == i], axis=0) for i in range(3)])

  18. print 'K-Means均值 = \n', mu1

  19. print '分類正確率為',np.mean(y_hat1==y)

  20. gmm=GaussianMixture(n_components=3,covariance_type='full', random_state=0)

  21. gmm.fit(x)

  22. print 'GMM均值 = \n', gmm.means_

  23. y_hat2=gmm.predict(x)

  24. y_hat2[y_hat2==1]=3

  25. y_hat2[y_hat2==2]=1

  26. y_hat2[y_hat2==3]=2

  27. print '分類正確率為',np.mean(y_hat2==y)

輸出結果為

實際均值 =

[[5.006  3.418]

 [ 5.936 2.77 ]

 [ 6.588 2.974]]

K-Means均值 =

[[5.006       3.418     ]

 [ 5.77358491 2.69245283]

 [ 6.81276596 3.07446809]]

分類正確率為 0.82

GMM均值 =

[[5.01494511  3.44040237]

 [ 6.69225795 3.03018616]

 [ 5.90652226 2.74740414]]

分類正確率為 0.786666666667

啊,這樣的結果還是比較讓人滿意的,甚至比前面的一些監督學習的結果還要好一些……另外,標籤不一致的問題我這裡採用的是最蠢的手動調整,大家當然可以根據你算出的均值矩陣每行與原始均值矩陣哪行的距離最小,確定它在原始資料中的標籤自動調整,這當然是OK的,我這裡偷一點懶。