1. 程式人生 > >EM(期望最大演算法)在高斯混合模型中的python實現

EM(期望最大演算法)在高斯混合模型中的python實現

以下程式碼僅實現了兩個高斯混合模型在均勻分佈條件下的引數估計,想要實現完全隨機的非均勻分佈的多高斯混合模型,可在上面加以修改。具體參考書中的9.3.2節

##python實現##
import math
#import copy
import numpy as np
import matplotlib.pyplot as plt

isdebug = False

# 指定k個高斯分佈引數。這裡指定k=2。

#注意2個高斯分佈具有同樣均方差Sigma。分別為Mu1,Mu2。


def ini_data(Sigma,Mu,k,N):
    global X
    global uMu
    global
Expectations global S X = np.zeros((1,N)) uMu = np.random.random(2)*5 S = np.random.random(2)*4 #uMu = np.array([10,30]) #S = np.array([5,2]) Expectations = np.zeros((N,k)) for i in range(0,N): if np.random.random(1) > 0.5: X[0,i] = np.random.normal()*Sigma[0
] + Mu[0] else: X[0,i] = np.random.normal()*Sigma[1] + Mu[1] if(not isdebug): print("***********") print(u"初始觀測資料X:") print(X) # EM演算法:步驟1。計算E[zij] def e_step(Sigma,k,N): global Expectations global uMu global X global S for i in range(0
,N): Denom = 0 for j in range(0,k): Denom += 0.5*(1/(float(S[j]*math.sqrt(2*math.pi))))*math.exp((-1/(2*(float(Sigma[j]**2))))*(float(X[0,i]-uMu[j]))**2) #print(Denom) for j in range(0,k): Numer = 0.5*(1/(float(S[j]*math.sqrt(2*math.pi))))*math.exp((-1/(2*(float(Sigma[j]**2))))*(float(X[0,i]-uMu[j]))**2) Expectations[i,j] = Numer / Denom if(isdebug): print("***********") print(u"隱藏變數E(Z):") #print(Expectations) # EM演算法:步驟2。求最大化E[zij]的引數Mu def m_step(k,N): global Expectations global X for j in range(0,k): Numer = 0 Denom = 0 sumSi = 0 for i in range(0,N): Numer += Expectations[i,j]*X[0,i] Denom +=Expectations[i,j] uMu[j] = Numer / Denom for i in range(0,N): sumSi += Expectations[i,j]*((X[0,i]-uMu[j])**2) #Denom +=Expectations[i,j] #print('sumSi ' + str(sumSi)) #print('Denom ' + str(Denom)) S[j] = math.sqrt(sumSi / Denom) # 演算法迭代iter_num次,或達到精度Epsilon停止迭代 def run(Sigma,Mu,k,N,iter_num,Epsilon): ini_data(Sigma,Mu,k,N) print(uMu) for i in range(iter_num): print(i) #Old_uMu = copy.deepcopy(uMu) e_step(Sigma,k,N) m_step(k,N) print(uMu) print(S) ''' if(sum(abs(uMu-Old_uMu)) < Epsilon): break ''' if __name__ == '__main__': run([6,15],[48,156],2,10000,200,0.0001) plt.hist(X[0,:],50) plt.show()