
【sklearn】Training an LDA Topic Model with sklearn and Tuning Its Parameters

Life is short, I use Python, and I am especially fond of sklearn. Besides the basic machine-learning building blocks (preprocessing, feature extraction and selection, classification, clustering, and so on), sklearn also exposes interfaces to many common language models, and sklearn.decomposition.LatentDirichletAllocation is one of them. In addition to introducing the model's basic parameters and how to call it for training, this post presents several workable strategies for tuning LDA, offered here for reference and discussion. To keep the length reasonable, the derivation of LDA itself is omitted; readers who want the theory should work through "LDA數學八卦", which is absolutely worth the effort.

LDA topic model training and tuning

(1) Loading the corpus and preprocessing

The corpus used here is the 20newsgroups dataset available through sklearn's built-in API. It contains news articles from many domains, including business, technology, sports and aerospace, which makes it well suited to NLP beginners.

sklearn_20newsgroups provides a very detailed introduction to the dataset.
For preprocessing, I simply call NLTK to lowercase the text, tokenize it, remove stopwords, filter by POS and stem. Which of these steps to apply depends entirely on your data and your goals; I often drop stemming or the POS filter myself (usually because the results come out worse ==)... The code below loads the 20newsgroups data and performs the text preprocessing.

# Load the data
from sklearn.datasets import fetch_20newsgroups
dataset = fetch_20newsgroups(shuffle=True, random_state=1,
                             remove=('headers', 'footers', 'quotes'))
n_samples = 2000
data_samples = dataset.data[:n_samples]  # keep only the amount we need, n_samples=2000

# Text preprocessing; every step is optional
import nltk
import string
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer

def textPrecessing(text):
    # lowercase
    text = text.lower()
    # strip punctuation
    for c in string.punctuation:
        text = text.replace(c, ' ')
    # tokenize
    wordLst = nltk.word_tokenize(text)
    # remove stopwords
    filtered = [w for w in wordLst if w not in stopwords.words('english')]
    # keep only nouns (or whatever POS you prefer)
    refiltered = nltk.pos_tag(filtered)
    filtered = [w for w, pos in refiltered if pos.startswith('NN')]
    # stemming
    ps = PorterStemmer()
    filtered = [ps.stem(w) for w in filtered]
    return " ".join(filtered)

The code above runs quickly only because I randomly (shuffle=True) kept just n_samples=2000 articles. With a larger corpus, preprocessing usually takes noticeably longer, so if the text data does not change it is best to save the preprocessed result; every later run then only needs to read it back from the file.

#Run this block only on the first execution to preprocess the text; comment it out afterwards
docLst = []
for desc in data_samples :
    docLst.append(textPrecessing(desc).encode('utf-8'))
with open(textPre_FilePath, 'w') as f:
    for line in docLst:
        f.write(line+'\n')

#==============================================================================
#From the second run on, load the preprocessed docLst directly; comment out the data loading and preprocessing above
#docLst = []
#with open(textPre_FilePath, 'r') as f:
#    for line in f.readlines():
#        if line != '':
#            docLst.append(line.strip())
#==============================================================================

Below are two 20newsgroups articles picked at random together with their preprocessed versions. POS filtering and stemming were skipped for this printout to keep it easy to read.
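A printout like the one below can be produced with something as simple as the following sketch (slicing to two documents is just for readability; docLst holds the preprocessed strings built above):

print(data_samples[:2])   # raw 20newsgroups articles
print(docLst[:2])         # the same articles after preprocessing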

Output:
Original 20Newsgroups Articles: [u"Well i'm not sure about the story nad it did seem biased. What\nI disagree with is your statement that the U.S. Media is out to\nruin Israels reputation. That is rediculous. The U.S. media is\nthe most pro-israeli media in the world. Having lived in Europe\nI realize that incidences such as the one described in the\nletter have occured. The U.S. media as a whole seem to try to\nignore them. The U.S. is subsidizing Israels existance and the\nEuropeans are not (at least not to the same degree). So I think\nthat might be a reason they report more clearly on the\natrocities.\n\tWhat is a shame is that in Austria, daily reports of\nthe inhuman acts commited by Israeli soldiers and the blessing\nreceived from the Government makes some of the Holocaust guilt\ngo away. After all, look how the Jews are treating other races\nwhen they got power. It is unfortunate.\n",
 u'\nJames Hogan writes:\n\ntimmbake@mcl.ucsb.edu (Bake Timmons) writes:\n>>Jim Hogan quips:\n\n>>... (summary of Jim\'s stuff)\n\n>>Jim, I\'m afraid _you\'ve_ missed the point.\n\n>>>Thus, I think you\'ll have to admit that  atheists have a lot\n>>more up their sleeve than you might have suspected.\n\n>>Nah.  I will encourage people to learn about atheism to see how little atheists\n>>have up their sleeves.  Whatever I might have suspected is actually quite\n>>meager.  If you want I\'ll send them your address to learn less about your\n>>faith.\n\n>Faith?\n\nYeah, do you expect people to read the FAQ, etc. and actually accept hard\natheism?  No, you need a little leap of faith, Jimmy.  Your logic runs out\nof steam!\n\n>>>Fine, but why do these people shoot themselves in the foot and mock\n>>>the idea of a God?  ....\n\n>>>I hope you understand now.\n\n>>Yes, Jim.  I do understand now.  Thank you for providing some healthy sarcasm\n>>that would have dispelled any sympathies I would have had for your faith.\n\n>Bake,\n\n>Real glad you detected the sarcasm angle, but am really bummin\' that\n>I won\'t be getting any of your sympathy.  Still, if your inclined\n>to have sympathy for somebody\'s *faith*, you might try one of the\n>religion newsgroups.\n\n>Just be careful over there, though. (make believe I\'m\n>whispering in your ear here)  They\'re all delusional!\n\nJim,\n\nSorry I can\'t pity you, Jim.  And I\'m sorry that you have these feelings of\ndenial about the faith you need to get by.  Oh well, just pretend that it will\nall end happily ever after anyway.  Maybe if you start a new newsgroup,\nalt.atheist.hard, you won\'t be bummin\' so much?\n\n>Good job, Jim.\n>.\n\n>Bye, Bake.\n\n\n>>[more slim-Jim (tm) deleted]\n\n>Bye, Bake!\n>Bye, Bye!\n\nBye-Bye, Big Jim.  Don\'t forget your Flintstone\'s Chewables!  :) \n--\nBake Timmons, III\n\n-- "...there\'s nothing higher, stronger, more wholesome and more useful in life\nthan some good memory..." -- Alyosha in Brothers Karamazov (Dostoevsky)\n']

Articles After Preprocessing: [u'well sure story nad seem biased disagree statement u media ruin israels reputation rediculous u media pro israeli media world lived europe realize incidences one described letter occured u media whole seem try ignore u subsidizing israels existance europeans least degree think might reason report clearly atrocities shame austria daily reports inhuman acts commited israeli soldiers blessing received government makes holocaust guilt go away look jews treating races got power unfortunate',
 u'james hogan writes timmbake mcl ucsb edu bake timmons writes jim hogan quips summary jim stuff jim afraid missed point thus think admit atheists lot sleeve might suspected nah encourage people learn atheism see little atheists sleeves whatever might suspected actually quite meager want send address learn less faith faith yeah expect people read faq etc actually accept hard atheism need little leap faith jimmy logic runs steam fine people shoot foot mock idea god hope understand yes jim understand thank providing healthy sarcasm would dispelled sympathies would faith bake real glad detected sarcasm angle really bummin getting sympathy still inclined sympathy somebody faith might try one religion newsgroups careful though make believe whispering ear delusional jim sorry pity jim sorry feelings denial faith need get oh well pretend end happily ever anyway maybe start new newsgroup alt atheist hard bummin much good job jim bye bake slim jim tm deleted bye bake bye bye bye bye big jim forget flintstone chewables bake timmons iii nothing higher stronger wholesome useful life good memory alyosha brothers karamazov dostoevsky']

(2) Building the document-word matrix

The training input for LDA is not a collection of raw documents but a document-word matrix. It can be a dense array or a sparse matrix of shape n_samples * n_features, where n_features is the number of terms. So before training the LDA topic model we first use CountVectorizer to count term frequencies and save the result:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.externals import joblib  # pickle or any other persistence tool works just as well

# Build the term-count vectorizer and save it; run only on the first execution
n_features = 2500
tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
                                max_features=n_features,
                                stop_words='english')
tf = tf_vectorizer.fit_transform(docLst)
joblib.dump(tf_vectorizer, tf_ModelPath)
#==============================================================================
# # Load the saved tf_vectorizer instead, to skip refitting the vocabulary
# tf_vectorizer = joblib.load(tf_ModelPath)
# tf = tf_vectorizer.transform(docLst)
#==============================================================================

For the full CountVectorizer API please see the sklearn documentation. The code above requires each term to appear in at least two documents (min_df=2) and keeps at most the n_features=2500 most frequent terms as features. The fitted tf_vectorizer is persisted with joblib, so from the second run on it can be loaded from file instead of being recomputed. The resulting tf matrix is a sparse document-term matrix, and tf_vectorizer.get_feature_names() returns the term behind each feature dimension.
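A quick way to sanity-check this step, reusing the tf and tf_vectorizer objects from above (note that recent sklearn releases replace get_feature_names with get_feature_names_out):

print(tf.shape)                                # (n_samples, n_features), e.g. (2000, 2500)
print(tf_vectorizer.get_feature_names()[:10])  # the terms behind the first ten feature dimensions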

(3) Training the LDA topic model

At last, the key step: training the LDA topic model. Although this is the crucial stage, if the data quality is good and none of the earlier steps were skimped on, it largely takes care of itself; otherwise all the accumulated problems tend to surface here at once. The two prerequisites for a good topic model are data quality and careful text preprocessing. While we are at it, two preprocessing packages that are a pleasure to use: jieba for Chinese and spaCy for English. I fell back on NLTK above only because spaCy refuses to install on this machine, alas.
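As an aside, a minimal spaCy-based replacement for the NLTK preprocessing above could look like the sketch below. It assumes the small English model en_core_web_sm has been downloaded, and it uses lemmatization instead of Porter stemming:

import spacy

nlp = spacy.load("en_core_web_sm")  # assumes: python -m spacy download en_core_web_sm

def spacy_preprocess(text):
    doc = nlp(text.lower())
    # keep alphabetic, non-stopword nouns and reduce them to their lemma
    tokens = [t.lemma_ for t in doc
              if t.is_alpha and not t.is_stop and t.pos_ in ("NOUN", "PROPN")]
    return " ".join(tokens)

docLst = [spacy_preprocess(d) for d in data_samples]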
But back to the topic. The LDA training code follows; for the meaning of each parameter, see the appendix "sklearn LDA API explained" at the end.

from sklearn.decomposition import LatentDirichletAllocation
n_topics = 30
lda = LatentDirichletAllocation(n_topics=n_topics,
                                max_iter=50,
                                learning_method='batch')
lda.fit(tf)  # tf is the document-word sparse matrix built above
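A version note: the n_topics argument used throughout this post belongs to the older sklearn releases it was written against; sklearn 0.19 deprecated it in favour of n_components and later releases removed it entirely. On a current installation the equivalent call would look roughly like this:

from sklearn.decomposition import LatentDirichletAllocation

lda = LatentDirichletAllocation(n_components=30,      # formerly n_topics
                                max_iter=50,
                                learning_method='batch')
lda.fit(tf)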

(4) Results

How long LDA takes to train varies a lot with max_iter and with how quickly the data converges. For a quick test, a few dozen iterations usually finish very fast; for real applications I would recommend at least a thousand or so.
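If you want to watch convergence while fitting, one option is the evaluate_every and verbose arguments, which (in the sklearn versions this post targets) make fit print the training perplexity every few EM iterations. A sketch:

lda = LatentDirichletAllocation(n_topics=30,
                                max_iter=50,
                                learning_method='batch',
                                evaluate_every=5,   # compute perplexity every 5 iterations
                                verbose=1)          # print progress while fitting
lda.fit(tf)

This makes it easier to judge how large max_iter actually needs to be before the perplexity stops improving.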

Topic top words

def print_top_words(model, feature_names, n_top_words):
    # print the highest-weighted terms for each topic
    for topic_idx, topic in enumerate(model.components_):
        print "Topic #%d:" % topic_idx
        print " ".join([feature_names[i]
                        for i in topic.argsort()[:-n_top_words - 1:-1]])
    print
    # print the full topic-word weight matrix
    print model.components_

n_top_words=20
tf_feature_names = tf_vectorizer.get_feature_names()
print_top_words(lda, tf_feature_names, n_top_words)

Output:
#highest-weighted words under each topic
Topic #0:
mail edu thanks new send email 00 com internet interested info uk price ac know sale fax copy data following
Topic #1:
gm win rochester edu michael new fred vs adams tommy gov nick gb main hudson issue alaska nasa space people
Topic #2:
55 10 11 18 21 17 13 19 16 period 22 23 14 20 25 15 24 12 93 26
Topic #3:
color server motif software input output edu support clock 256 bits linux vga shots default mode level using image xterm
Topic #4:
edu writes article com know like uiuc cc news cs people cso opinions think david really way right heard sure
Topic #5:
section military shall dangerous firearm weapon law person state license use means following women designed islamic japanese division men issued
Topic #6:
like know time good bike com really writes course year ride going think got read live years better big high
Topic #7:
com edu writes article list andrew apple cmu cs sandvik points toronto ca kent vancouver sphere power point portal cup
Topic #8:
know ca black use white edu think writes light like signal right old used dave bnr want mouse led let
Topic #9:
drive disk drives hard controller rom card bios floppy flyers 16 feature supports board speed bus interface power mb data
Topic #10:
people government think president american weapons country clinton mr support time billion make new say like going state states jobs
Topic #11:
edu insurance hp writes article like offer cable best turbo use port power se speed hd good 25 swap year
Topic #12:
food edu msg writes article standard frank use objective red blues people bear cs area values begin like wings rick
Topic #13:
earth probe moon lunar orbit mission surface mars space spacecraft venus solar jupiter science atmosphere planet planetary images data pioneer
Topic #14:
edu com want good dog writes buy dod sold question dealer article water nec large make used chris audio hp
Topic #15:
israel jews israeli arab jewish attacks state peace people land policy lebanese arabs right say nazi writes men fact soldiers
Topic #16:
com gun writes guns article crime 000 self edu likely isc stratus make texas fbi government way br steve defense
Topic #17:
scsi bit mac 32 tv fast ide cards ibm chip 16 set difference better bytes fpu faster computer use piece
Topic #18:
edu ftp version pc contact machines available type pub au comments mit anonymous sun mac program unix math looking written
Topic #19:
car cars turkish engine greek oil tires speed turks brake miles greeks 000 better new brakes good dot tire wheel
Topic #20:
god people think jesus edu believe say bible way good know christian point life like church law time faith says
Topic #21:
use using key number time like want used problem idea need know serial example code data traffic application keys case
Topic #22:
university april science 1993 research disease program health information new study medicine power energy computer papers time process development conference
Topic #23:
space years nasa gov new year launch 10 sci pitt gay shuttle km 15 article medical titan soon high 1990
Topic #24:
people said went know going time children think like came home killed happened took armenians come got told away dead
Topic #25:
graphics image mail pub edu aids ray 128 files package mil images 3d send sgi computer systems archive gov format
Topic #26:
windows file problem use edu window thanks files help card know dos like monitor using memory work video program need
Topic #27:
game team play year players season think games hockey player win cubs teams better good baseball ca fan leafs league
Topic #28:
writes com edu article atheism bob jim tek word rights used people news case keith alt said term time given
Topic #29:
government key encryption chip clipper public use keys law people enforcement private nsa security like secure phone com think care

#topic-word weight matrix
array([[  1.00377390e+02,   3.33333333e-02,   3.33333333e-02, ...,
          3.33333333e-02,   3.33333333e-02,   3.33333333e-02],
       [  3.33333333e-02,   3.33333333e-02,   3.33333333e-02, ...,
          3.33333333e-02,   3.33333333e-02,   3.33333333e-02],
       [  1.13445534e+01,   3.33333333e-02,   1.31402890e+01, ...,
          3.33333333e-02,   3.33333333e-02,   3.33333333e-02],
       ...,
       [  3.33333333e-02,   3.33333333e-02,   3.33333333e-02, ...,
          3.33333333e-02,   9.23349606e+00,   3.33333333e-02],
       [  3.33333333e-02,   3.33333333e-02,   3.33333333e-02, ...,
          3.33333333e-02,   3.33333333e-02,   3.33333333e-02],
       [  3.33333333e-02,   3.33333333e-02,   3.33333333e-02, ...,
          3.33333333e-02,   3.33333333e-02,   3.33333333e-02]])

A quick scan of each topic's top words suggests the model is broadly sensible: education-related words end up together, machinery-related words end up together, and so on. There are also obvious problems, partly because training has not fully converged and partly because stemming was skipped, so both "car" and "cars" land in Topic #19; that is something to avoid in your own runs.

Doc-topic results

One of the main reasons to train LDA in the first place is to analyse the topic distribution of individual documents; that is where the model delivers most of its value. Converting documents into topic distributions with the trained model is a single call, shown with its result below:

doc_topic_dist = lda.transform(tf)

Output:
array([[  0.03333333,   0.03333333,   0.03333333, ...,   0.03333333,
          0.03333333,   0.03333333],
       [  0.03333333,   0.03333333,   0.03333333, ...,   1.9426311 ,
         26.11962169,   0.03333333],
       [  0.03333333,   0.03333333,   0.03333333, ...,   0.03333333,
          0.03333333,   0.03333333],
       ...,
       [  0.03333333,   0.03333333,  15.99360499, ...,   0.03333333,
          0.03333333,   0.03333333],
       [  0.03333333,   0.03333333,   0.03333333, ...,   0.03333333,
          0.03333333,   0.03333333],
       [  0.03333333,   0.03333333,   0.03333333, ...,  13.36262244,
          0.03333333,   0.03333333]])

Earlier I showed two example articles; their dominant topics come out as topic #12 and topic #20, and you can judge the quality for yourself. Admittedly the result is not great, for several possible reasons: the parameters have not been tuned yet, and the preprocessing skipped stemming and POS filtering to save time, both of which you can simply add back.
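Note that the rows of the matrix above clearly do not sum to 1: the older sklearn version used here returns unnormalized topic weights from transform, whereas recent versions return proper distributions. A small sketch, reusing doc_topic_dist from the call above, that normalizes the rows (a no-op on recent versions) and picks each document's dominant topic:

# turn the weights into per-document distributions
doc_topic_dist = doc_topic_dist / doc_topic_dist.sum(axis=1, keepdims=True)
# index of the most probable topic for every document
top_topic = doc_topic_dist.argmax(axis=1)
print(top_topic[:10])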

Convergence (perplexity)

Calling lda.perplexity(X) gives the perplexity of the current model on X. sklearn defines perplexity as exp(-1. * log-likelihood per word).

lda.perplexity(tf)

Output: 
1270.5358245980792

This run used relatively few iterations, so the model has not converged and the perplexity is noticeably high; tuning the parameters yields a more reliable model.
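Perplexity measured on the training data itself is also somewhat optimistic; if you can spare some documents, the perplexity on held-out text is a more telling number. A minimal sketch, reusing the tf_vectorizer fitted above and assuming docLst is large enough to split:

from sklearn.model_selection import train_test_split

train_docs, test_docs = train_test_split(docLst, test_size=0.2, random_state=1)
tf_train = tf_vectorizer.transform(train_docs)
tf_test = tf_vectorizer.transform(test_docs)

lda = LatentDirichletAllocation(n_topics=30, max_iter=50, learning_method='batch')
lda.fit(tf_train)
print(lda.perplexity(tf_test))   # perplexity on unseen documents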

(5) (Optional) Parameter tuning

Parameters worth tuning

  • n_topics: the number of topics
  • n_features: the number of features, i.e. the size of the vocabulary that is kept
  • doc_topic_prior: the parameter α of the Dirichlet prior on the per-document topic distribution θd
  • topic_word_prior: the parameter η of the Dirichlet prior on the per-topic word distribution βk
  • learning_method: the inference algorithm, either 'batch' or 'online'
  • the remaining sklearn parameters: depending on the chosen algorithm there are a few more knobs to turn; see the appendix "sklearn LDA API explained" at the end.

Two practical tuning strategies

First, taking n_topics as an example, select the best model by perplexity. Of course, changing the number of topics changes the perplexity computation itself, so perplexity can only serve as a reference; the final topic count still has to be set according to the needs of the application. The n_topics tuning code is as follows:

import os
from time import time
import matplotlib.pyplot as plt

n_topics = range(20, 75, 5)
perplexityLst = [1.0]*len(n_topics)

# Train an LDA model for each topic count and report the training time
lda_models = []
for idx, n_topic in enumerate(n_topics):
    lda = LatentDirichletAllocation(n_topics=n_topic,
                                    max_iter=20,
                                    learning_method='batch',
                                    evaluate_every=200,
#                                    perp_tol=0.1, #default
#                                    doc_topic_prior=1/n_topic, #default
#                                    topic_word_prior=1/n_topic, #default
                                    verbose=0)
    t0 = time()
    lda.fit(tf)
    perplexityLst[idx] = lda.perplexity(tf)
    lda_models.append(lda)
    print "# of Topic: %d, " % n_topics[idx],
    print "done in %0.3fs, N_iter %d, " % ((time() - t0), lda.n_iter_),
    print "Perplexity Score %0.3f" % perplexityLst[idx]

# Report the best model
best_index = perplexityLst.index(min(perplexityLst))
best_n_topic = n_topics[best_index]
best_model = lda_models[best_index]
print "Best # of Topic: ", best_n_topic

# Plot perplexity against the number of topics
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.plot(n_topics, perplexityLst)
ax.set_xlabel("# of topics")
ax.set_ylabel("Approximate Perplexity")
plt.grid(True)
plt.savefig(os.path.join('lda_result', 'perplexityTrend'+CODE+'.png'))  # CODE: user-defined tag for the filename
plt.show()

Output:
Best # of Topic:  25
![Perplexity trend for different numbers of topics](http://img.blog.csdn.net/20170731171742934?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvVGlmZmFueVJhYmJpdA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast)

Second, if you want to tune everything in one go you can run cross-validated grid search directly with sklearn, but be warned that it will take a very long time. The code below is only a starting point; add or remove parameters as your needs dictate.

from sklearn.model_selection import GridSearchCV
parameters = {'learning_method': ('batch', 'online'),
              'n_topics': range(20, 75, 5),
              'perp_tol': (0.001, 0.01, 0.1),
              'doc_topic_prior': (0.001, 0.01, 0.05, 0.1, 0.2),
              'topic_word_prior': (0.001, 0.01, 0.05, 0.1, 0.2),
              'max_iter': [1000]}
lda = LatentDirichletAllocation()
model = GridSearchCV(lda, parameters)
model.fit(tf)

sorted(model.cv_results_.keys())
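By default GridSearchCV scores every candidate with the estimator's own score method, which for LatentDirichletAllocation is the approximate log-likelihood on the held-out folds. Once the search finishes, the winning configuration and the refitted model can be read off like this:

print(model.best_params_)          # best parameter combination found
print(model.best_score_)           # its mean cross-validated score (approximate log-likelihood)
best_lda = model.best_estimator_   # LDA refitted on the full data with those parameters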

Appendix: sklearn LDA API explained

Class sklearn.decomposition.LatentDirichletAllocation(n_topics=10, doc_topic_prior=None, topic_word_prior=None, learning_method=None, learning_decay=0.7, learning_offset=10.0, max_iter=10, batch_size=128, evaluate_every=-1, total_samples=1000000.0, perp_tol=0.1, mean_change_tol=0.001, max_doc_update_iter=100, n_jobs=1, verbose=0, random_state=None)

Parameters:
1) n_topics: the number of latent topics K; needs tuning. How large K should be depends on how fine-grained a topic split you need. If a coarse split is enough, say telling animals from plants from inanimate objects, a single-digit K will do. If the goal is to distinguish different animals, different plants and different inanimate objects, K may need to be very large, in the thousands or tens of thousands, which in turn requires a very large number of training documents.
2) doc_topic_prior: the parameter α of the Dirichlet prior on the per-document topic distribution θd. Without prior knowledge of the topic distribution, the default 1/K is fine.
3) topic_word_prior: the parameter η of the Dirichlet prior on the per-topic word distribution βk. Without prior knowledge of the topic-word distribution, the default 1/K is fine.
4) learning_method: the inference algorithm, either 'batch' or 'online'. 'batch' is the variational EM algorithm; 'online' is online variational EM, which builds on 'batch' by splitting the training samples into mini-batches and updating the topic-word distribution batch by batch. The default is 'online', and choosing 'online' also lets you train incrementally with partial_fit (a sketch follows the Methods list below); note that scikit-learn 0.20 changes the default back to 'batch'. If the corpus is small and you are just learning, 'batch' is recommended because there are fewer parameters to tune; for very large corpora, 'online' is the better choice.
5) learning_decay: only meaningful with the 'online' method. It controls the learning rate and should stay in (0.5, 1.0] so that the online algorithm converges asymptotically. The default is 0.7 and rarely needs changing.
6) learning_offset: only meaningful with the 'online' method; must be greater than 1. It down-weights the influence of the early mini-batches on the final model.
7) max_iter: the maximum number of EM iterations.
8) total_samples: only meaningful with the 'online' method; the total number of documents, which is needed when training with partial_fit.
9) batch_size: only meaningful with the 'online' method; the number of documents used in each EM iteration.
10) mean_change_tol: the threshold for the E-step updates of the variational parameters; once all updates fall below it, the E-step ends and the M-step begins. The default rarely needs changing.
11) max_doc_update_iter: the maximum number of E-step iterations for updating the variational parameters; when it is reached the algorithm moves on to the M-step.

Methods:
1) fit(X[, y]): train the model on X, the document-word count matrix.
2) fit_transform(X[, y]): train the model and return the topic distribution of the training documents.
3) get_params([deep]): get the parameters.
4) partial_fit(X[, y]): train the model incrementally on mini-batches in online mode (see the sketch below).
5) perplexity(X[, doc_topic_distr, sub_sampling]): compute the approximate perplexity of X.
6) score(X[, y]): compute the approximate log-likelihood.
7) set_params(**params): set the parameters.
8) transform(X): return the topic distribution of every document in X under the trained model.
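For completeness, here is a minimal sketch of the online / partial_fit workflow mentioned in 4) above (parameter names follow the 0.18-era signature shown in this appendix; recent releases use n_components instead of n_topics):

from sklearn.decomposition import LatentDirichletAllocation

lda = LatentDirichletAllocation(n_topics=30,
                                learning_method='online',
                                total_samples=tf.shape[0])  # total corpus size, used by partial_fit

# feed the document-word matrix in mini-batches of 128 documents
for start in range(0, tf.shape[0], 128):
    lda.partial_fit(tf[start:start + 128])

print(lda.perplexity(tf))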

References: