1. 程式人生 > >【EM演算法】在高斯混合模型中的應用及python示例

【EM演算法】在高斯混合模型中的應用及python示例

一、EM演算法

EM演算法是一種迭代演算法,用於含有隱含變數的概率模型引數的極大似然估計。設Y為觀測隨機變數的資料,Z為隱藏的隨機變數資料,Y和Z一起稱為完全資料。

觀測資料的似然函式為:


模型引數θ的極大似然估計為:


這個問題只有通過迭代求解,下面給出EM演算法的迭代求解過程:

step1、選擇合適的引數初值θ(0),開始迭代

step2、E步,求期望。θ(i)為第i次迭代θ的估計值,在第i+1步,計算下面的Q函式:


Q函式為logP(Y,Z|θ)關於在給定觀測資料Y和當前引數θ(i)下對隱藏變數Z的條件概率分佈P(Z|Y,θ(i))的期望。

step3、M步,求極大化。求使Q函式極大化的θ,確定第i+1次迭代的引數估計:


step4、重複第2、3步,直到收斂。

EM演算法對初值的選取比較敏感,且不能保證找到全域性最優解。

二、在高斯混合模型(GMM)中的應用

一維高斯混合模型:


多維高斯混合模型:


wk(k=1,2,……,K)為混合項係數,和為1。∑為協方差矩陣。θ=(wk,uk,σk)。

設有N個可觀測資料yi,它們是這樣產生的:先根據概率wk選擇第k個高斯分佈模型,生成觀測資料yi。yi是已知的,但yi屬於第j個模型是未知的,是隱藏變數。用Zij表示隱藏變數,含義是第i個數據屬於第j個模型的概率。先寫出完全資料的似然函式,然後確定Q函式,要最大化期望,對wk、uk、σk求偏導並使其為0。可得高斯混合模型引數估計的EM演算法(以高維資料為例):

step1、引數賦初始值,開始迭代

step2、E步,計算混合項係數Zij的期望E[Zij]:


step3、M步,計算新一輪迭代的引數模型:




step4、重複第2、3步,直到收斂。

三、python程式示例

此示例程式隨機從4個高斯模型中生成500個2維資料,真實引數:混合項w=[0.1,0.2,0.3,0.4],均值u=[[5,35],[30,40],[20,20],[45,15]],協方差矩陣∑=[[30,0],[0,30]]。然後以這些資料作為觀測資料,根據EM演算法來估計以上引數(此程式未估計協方差矩陣)。原始碼如下:

import math
import copy
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#生成隨機資料,4個高斯模型
def generate_data(sigma,N,mu1,mu2,mu3,mu4,alpha):
    global X                  #可觀測資料集
    X = np.zeros((N, 2))       # 初始化X,2行N列。2維資料,N個樣本
    X=np.matrix(X)
    global mu                 #隨機初始化mu1,mu2,mu3,mu4
    mu = np.random.random((4,2))
    mu=np.matrix(mu)
    global excep              #期望第i個樣本屬於第j個模型的概率的期望
    excep=np.zeros((N,4))
    global alpha_             #初始化混合項係數
    alpha_=[0.25,0.25,0.25,0.25]
    for i in range(N):
        if np.random.random(1) < 0.1:  # 生成0-1之間隨機數
            X[i,:]  = np.random.multivariate_normal(mu1, sigma, 1)     #用第一個高斯模型生成2維資料
        elif 0.1 <= np.random.random(1) < 0.3:
            X[i,:] = np.random.multivariate_normal(mu2, sigma, 1)      #用第二個高斯模型生成2維資料
        elif 0.3 <= np.random.random(1) < 0.6:
            X[i,:] = np.random.multivariate_normal(mu3, sigma, 1)      #用第三個高斯模型生成2維資料
        else:
            X[i,:] = np.random.multivariate_normal(mu4, sigma, 1)      #用第四個高斯模型生成2維資料

    print("可觀測資料:\n",X)       #輸出可觀測樣本
    print("初始化的mu1,mu2,mu3,mu4:",mu)      #輸出初始化的mu

def e_step(sigma,k,N):
    global X
    global mu
    global excep
    global alpha_
    for i in range(N):
        denom=0
        for j in range(0,k):
            denom += alpha_[j]*math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/np.sqrt(np.linalg.det(sigma))       #分母
        for j in range(0,k):
            numer = math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/np.sqrt(np.linalg.det(sigma))        #分子
            excep[i,j]=alpha_[j]*numer/denom      #求期望
    print("隱藏變數:\n",excep)

def m_step(k,N):
    global excep
    global X
    global alpha_
    for j in range(0,k):
        denom=0   #分母
        numer=0   #分子
        for i in range(N):
            numer += excep[i,j]*X[i,:]
            denom += excep[i,j]
        mu[j,:] = numer/denom    #求均值
        alpha_[j]=denom/N        #求混合項係數

if __name__ == '__main__':
    iter_num=1000  #迭代次數
    N=500         #樣本數目
    k=4            #高斯模型數
    probility = np.zeros(N)    #混合高斯分佈
    u1=[5,35]
    u2=[30,40]
    u3=[20,20]
    u4=[45,15]
    sigma=np.matrix([[30, 0], [0, 30]])               #協方差矩陣
    alpha=[0.1,0.2,0.3,0.4]         #混合項係數
    generate_data(sigma,N,u1,u2,u3,u4,alpha)     #生成資料
    #迭代計算
    for i in range(iter_num):
        err=0     #均值誤差
        err_alpha=0    #混合項係數誤差
        Old_mu = copy.deepcopy(mu)
        Old_alpha = copy.deepcopy(alpha_)
        e_step(sigma,k,N)     # E步
        m_step(k,N)           # M步
        print("迭代次數:",i+1)
        print("估計的均值:",mu)
        print("估計的混合項係數:",alpha_)
        for z in range(k):
            err += (abs(Old_mu[z,0]-mu[z,0])+abs(Old_mu[z,1]-mu[z,1]))      #計算誤差
            err_alpha += abs(Old_alpha[z]-alpha_[z])
        if (err<=0.001) and (err_alpha<0.001):     #達到精度退出迭代
            print(err,err_alpha)
            break
    #視覺化結果
    # 畫生成的原始資料
    plt.subplot(221)
    plt.scatter(X[:,0], X[:,1],c='b',s=25,alpha=0.4,marker='o')    #T散點顏色,s散點大小,alpha透明度,marker散點形狀
    plt.title('random generated data')
    #畫分類好的資料
    plt.subplot(222)
    plt.title('classified data through EM')
    order=np.zeros(N)
    color=['b','r','k','y']
    for i in range(N):
        for j in range(k):
            if excep[i,j]==max(excep[i,:]):
                order[i]=j     #選出X[i,:]屬於第幾個高斯模型
            probility[i] += alpha_[int(order[i])]*math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/(np.sqrt(np.linalg.det(sigma))*2*np.pi)    #計算混合高斯分佈
        plt.scatter(X[i, 0], X[i, 1], c=color[int(order[i])], s=25, alpha=0.4, marker='o')      #繪製分類後的散點圖
    #繪製三維影象
    ax = plt.subplot(223, projection='3d')
    plt.title('3d view')
    for i in range(N):
        ax.scatter(X[i, 0], X[i, 1], probility[i], c=color[int(order[i])])
    plt.show()

結果如下:

混合項係數估計為[0.46878064954123966, 0.087906620835838722, 0.25716577653788636, 0.18614695308503548]

均值估計為[[ 45.20736093  15.47819894]
 [  3.74835753  34.93029857]
 [ 19.97541696  20.26373867]
 [ 29.91276386  39.87999686]]


左上圖為生成的觀測資料,右上圖為分類後的結果,下圖為高斯混合模型的三維視覺化圖。