【EM演算法】在高斯混合模型中的應用及python示例
一、EM演算法
EM演算法是一種迭代演算法,用於含有隱含變數的概率模型引數的極大似然估計。設Y為觀測隨機變數的資料,Z為隱藏的隨機變數資料,Y和Z一起稱為完全資料。
觀測資料的似然函式為:
模型引數θ的極大似然估計為:
這個問題只有通過迭代求解,下面給出EM演算法的迭代求解過程:
step1、選擇合適的引數初值θ(0),開始迭代
step2、E步,求期望。θ(i)為第i次迭代θ的估計值,在第i+1步,計算下面的Q函式:
Q函式為logP(Y,Z|θ)關於在給定觀測資料Y和當前引數θ(i)下對隱藏變數Z的條件概率分佈P(Z|Y,θ(i))的期望。
step3、M步,求極大化。求使Q函式極大化的θ,確定第i+1次迭代的引數估計:
step4、重複第2、3步,直到收斂。
EM演算法對初值的選取比較敏感,且不能保證找到全域性最優解。
二、在高斯混合模型(GMM)中的應用
一維高斯混合模型:
多維高斯混合模型:
wk(k=1,2,……,K)為混合項係數,和為1。∑為協方差矩陣。θ=(wk,uk,σk)。
設有N個可觀測資料yi,它們是這樣產生的:先根據概率wk選擇第k個高斯分佈模型,生成觀測資料yi。yi是已知的,但yi屬於第j個模型是未知的,是隱藏變數。用Zij表示隱藏變數,含義是第i個數據屬於第j個模型的概率。先寫出完全資料的似然函式,然後確定Q函式,要最大化期望,對wk、uk、σk求偏導並使其為0。可得高斯混合模型引數估計的EM演算法(以高維資料為例):
step1、引數賦初始值,開始迭代
step2、E步,計算混合項係數Zij的期望E[Zij]:
step3、M步,計算新一輪迭代的引數模型:
step4、重複第2、3步,直到收斂。
三、python程式示例
此示例程式隨機從4個高斯模型中生成500個2維資料,真實引數:混合項w=[0.1,0.2,0.3,0.4],均值u=[[5,35],[30,40],[20,20],[45,15]],協方差矩陣∑=[[30,0],[0,30]]。然後以這些資料作為觀測資料,根據EM演算法來估計以上引數(此程式未估計協方差矩陣)。原始碼如下:
import math import copy import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D #生成隨機資料,4個高斯模型 def generate_data(sigma,N,mu1,mu2,mu3,mu4,alpha): global X #可觀測資料集 X = np.zeros((N, 2)) # 初始化X,2行N列。2維資料,N個樣本 X=np.matrix(X) global mu #隨機初始化mu1,mu2,mu3,mu4 mu = np.random.random((4,2)) mu=np.matrix(mu) global excep #期望第i個樣本屬於第j個模型的概率的期望 excep=np.zeros((N,4)) global alpha_ #初始化混合項係數 alpha_=[0.25,0.25,0.25,0.25] for i in range(N): if np.random.random(1) < 0.1: # 生成0-1之間隨機數 X[i,:] = np.random.multivariate_normal(mu1, sigma, 1) #用第一個高斯模型生成2維資料 elif 0.1 <= np.random.random(1) < 0.3: X[i,:] = np.random.multivariate_normal(mu2, sigma, 1) #用第二個高斯模型生成2維資料 elif 0.3 <= np.random.random(1) < 0.6: X[i,:] = np.random.multivariate_normal(mu3, sigma, 1) #用第三個高斯模型生成2維資料 else: X[i,:] = np.random.multivariate_normal(mu4, sigma, 1) #用第四個高斯模型生成2維資料 print("可觀測資料:\n",X) #輸出可觀測樣本 print("初始化的mu1,mu2,mu3,mu4:",mu) #輸出初始化的mu def e_step(sigma,k,N): global X global mu global excep global alpha_ for i in range(N): denom=0 for j in range(0,k): denom += alpha_[j]*math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/np.sqrt(np.linalg.det(sigma)) #分母 for j in range(0,k): numer = math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/np.sqrt(np.linalg.det(sigma)) #分子 excep[i,j]=alpha_[j]*numer/denom #求期望 print("隱藏變數:\n",excep) def m_step(k,N): global excep global X global alpha_ for j in range(0,k): denom=0 #分母 numer=0 #分子 for i in range(N): numer += excep[i,j]*X[i,:] denom += excep[i,j] mu[j,:] = numer/denom #求均值 alpha_[j]=denom/N #求混合項係數 if __name__ == '__main__': iter_num=1000 #迭代次數 N=500 #樣本數目 k=4 #高斯模型數 probility = np.zeros(N) #混合高斯分佈 u1=[5,35] u2=[30,40] u3=[20,20] u4=[45,15] sigma=np.matrix([[30, 0], [0, 30]]) #協方差矩陣 alpha=[0.1,0.2,0.3,0.4] #混合項係數 generate_data(sigma,N,u1,u2,u3,u4,alpha) #生成資料 #迭代計算 for i in range(iter_num): err=0 #均值誤差 err_alpha=0 #混合項係數誤差 Old_mu = copy.deepcopy(mu) Old_alpha = copy.deepcopy(alpha_) e_step(sigma,k,N) # E步 m_step(k,N) # M步 print("迭代次數:",i+1) print("估計的均值:",mu) print("估計的混合項係數:",alpha_) for z in range(k): err += (abs(Old_mu[z,0]-mu[z,0])+abs(Old_mu[z,1]-mu[z,1])) #計算誤差 err_alpha += abs(Old_alpha[z]-alpha_[z]) if (err<=0.001) and (err_alpha<0.001): #達到精度退出迭代 print(err,err_alpha) break #視覺化結果 # 畫生成的原始資料 plt.subplot(221) plt.scatter(X[:,0], X[:,1],c='b',s=25,alpha=0.4,marker='o') #T散點顏色,s散點大小,alpha透明度,marker散點形狀 plt.title('random generated data') #畫分類好的資料 plt.subplot(222) plt.title('classified data through EM') order=np.zeros(N) color=['b','r','k','y'] for i in range(N): for j in range(k): if excep[i,j]==max(excep[i,:]): order[i]=j #選出X[i,:]屬於第幾個高斯模型 probility[i] += alpha_[int(order[i])]*math.exp(-(X[i,:]-mu[j,:])*sigma.I*np.transpose(X[i,:]-mu[j,:]))/(np.sqrt(np.linalg.det(sigma))*2*np.pi) #計算混合高斯分佈 plt.scatter(X[i, 0], X[i, 1], c=color[int(order[i])], s=25, alpha=0.4, marker='o') #繪製分類後的散點圖 #繪製三維影象 ax = plt.subplot(223, projection='3d') plt.title('3d view') for i in range(N): ax.scatter(X[i, 0], X[i, 1], probility[i], c=color[int(order[i])]) plt.show()
結果如下:
混合項係數估計為[0.46878064954123966, 0.087906620835838722, 0.25716577653788636, 0.18614695308503548]
均值估計為[[ 45.20736093 15.47819894]
[ 3.74835753 34.93029857]
[ 19.97541696 20.26373867]
[ 29.91276386 39.87999686]]
左上圖為生成的觀測資料,右上圖為分類後的結果,下圖為高斯混合模型的三維視覺化圖。