斯坦福大學機器學習——EM演算法求解高斯混合模型
EM演算法(Expection-Maximizationalgorithm,EM)是一種迭代演算法,通過E步和M步兩大迭代步驟,每次迭代都使極大似然函式增加。但是,由於初始值的不同,可能會使似然函式陷入區域性最優。辜麗川老師和其夫人發表的論文:基於分裂EM演算法的GMM引數估計(提取碼:77c0)改進了這一缺陷。下面來談談EM演算法以及其在求解高斯混合模型中的作用。
一、 高斯混合模型(Gaussian MixtureModel, GMM)
之前寫過高斯判別分析模型,利用引數估計的方法用於解決二分類問題。下面介紹GMM,它是對高斯判別模型的一個推廣,也能借此引入EM演算法。
假設樣本集為並且樣本和標籤滿足聯合分佈
該模型的似然函式為:
如果直接令的各變數偏導為0,試圖分別求出各引數,我們會發現根本無法求解。但如果變數是已知的,求解便容易許多,上面的似然函式可以表示為:
其中,#{ }為指示函式,表示滿足括號內條件的數目。
那麼,變數無法通過觀察直接得到,就稱為隱變數,就需要通過EM演算法,求解GMM了。下面從Jensen不等式開始,介紹下EM演算法:
二、 Jensen不等式(Jensen’s inequality)
引理:如果函式f的定義域為整個實數集,並且對於任意x或
(存在或函式的Hessian矩陣,那麼函式f稱為凸函式;如果或函式的Hessian矩陣 H<0,那麼函式f為嚴格凸函式。)
定理:如果函式f是凹函式,X為隨機變數,那麼:
不幸的是很多人都會講Jensen不等式記混,我們可以通過圖形的方式幫助記憶。下圖中,橫縱座標軸分別為X和f(X),f(x)為一個凹函式,a、b分別為變數X的定義域,E[X]為定義域X的期望。圖中清楚的看到各個量的位置和他們間的大小關係。反之,如果函式f是凸函式,X為隨機變數,那麼:
三、 EM演算法
假設訓練集是由m個獨立的樣本構成。我們的目的是要對概率密度函式進行引數估計。它的似然函式為:
然而僅僅憑藉似然函式,無法對引數進行求解。因為這裡的隨機變數是未知的。
EM演算法提供了一種巧妙的方式,可以通過逐步迭代逼近最大似然值。下面就來介紹下EM演算法:
其中第(2)步至第(3)步的推導就使用了Jensen不等式。其中:f(x)=log x,,因此為凸函式;表示隨機變數為概率分佈函式為的期望。因此有:
這樣,對於任意分佈,(3)都給出了的一個下界。如果我們現在通過猜測初始化了一個的值,我們希望得到在這個特定的下,更緊密的下界,也就是使等號成立。根據Jensen不等式等號成立的條件,當為一常數時,等號成立。即:
上述等式最後一步使用了貝葉斯公示。
EM演算法有兩個步驟:
(1)通過設定初始化值,求出使似然方程最大的值,此步驟稱為E-步(E-step)
(2)利用求出的值,更新。此步驟稱為M-步(M-step)。過程如下:
repeat until convergence{
(E-step) for each i, set
(M-step) set
}
那麼,如何保證EM演算法是收斂的呢?下面給予證明:
假設和是EM演算法第t次和第t+1次迭代所得到的引數的值,如果有,即每次迭代後似然方程的值都會增大,通過逐步迭代,最終達到最大值。以下是證明:
不等式(4)是由不等式(3)得到,對於任意和值都成立;得到不等式(5)是因為我們需要選擇特定的使得方程在處的值大於在處的值;等式(6)是找到特定的的值,使得等號成立。
最後我們通過圖形的方式再更加深入細緻的理解EM演算法的特點:
由上文我們知道有這樣的關係:,EM演算法就是不斷最大化這個下界,逐步得到似然函式的最大值。如下圖所示:
首先,初始化,調整使得與相等,然後求出使得到最大值的;固定,調整,使得與相等,然後求出使得到最大值的;……;如此迴圈,使得的值不斷上升,直到k次迴圈後,求出了的最大值。
四、 EM演算法應用於混合高斯模型(GMM)
再回到GMM:上文提到由於隱變數的存在,無法直接求解引數,但可以通過EM演算法進行求解:E-Step:
令上式為0,得:
(2)引數觀察M-Step,可以看到,跟相關的變數僅僅有。因此,我們僅僅需要最大化下面的目標函式:
又由於,為約束條件。因此,可以構造拉格朗日運算元求目標函式:
求偏導:
令上式為零,解得:
五、 總結
EM演算法利用不完全的資料,進行極大似然估計。通過兩步迭代,逐漸逼近最大似然值。而GMM可以利用EM演算法進行引數估計。
最後提下辜老師論文的思路:EM模型容易收斂到區域性最大值,並且嚴重依賴初試值。傳統的方法即上文中使用的方法是每次迭代過程中,同時更新高斯分佈中所有引數,而辜老師的方法是把K個高斯分佈中的一個分量,利用奇異值分解的方法將其分裂為兩個高斯分佈,並保持其他分量不變的情況下,對共這K+1個高斯分佈的權值進行更新,直到符合一定的收斂條件。這樣一來,雖然演算法複雜度沒有降低,但每輪只需要更新兩個引數,大大降低了每輪迭代的計算量。