1. 程式人生 > >LDA模型數據的可視化

LDA模型數據的可視化

好的 strip pan remove 從大到小 ems open 可視化 except

 1 """
 2     執行lda2vec.ipnb中的代碼
 3     模型LDA
 4     功能:訓練好後模型數據的可視化
 5 """
 6 
 7 from lda2vec import preprocess, Corpus
 8 import matplotlib.pyplot as plt
 9 import numpy as np
10 # %matplotlib inline
11 import pyLDAvis
12 try:
13     import seaborn
14 except:
15     pass
16 # 加載訓練好的主題-文檔模型,這裏是查看數據使用。這裏需要搞清楚數據的形式,還要去回看這個文件是怎麽構成的
17 npz = np.load(open(D:/my_AI/lda2vec-master/examples/twenty_newsgroups/lda2vec/topics.pyldavis.npz, rb)) 18 # 數據 19 dat = {k: v for (k, v) in npz.iteritems()} 20 # 詞匯表變成list 21 dat[vocab] = dat[vocab].tolist() 22 23 ##################################### 24 ## 主題-詞匯 25 #####################################
26 # 主題個數為10 27 top_n = 10 28 # 主題對應10個最相關的詞 29 topic_to_topwords = {} 30 for j, topic_to_word in enumerate(dat[topic_term_dists]): 31 top = np.argsort(topic_to_word)[::-1][:top_n] # 概率從大到小的下標索引值 32 msg = Topic %i % j 33 # 通過list的下標獲取關鍵詞 34 top_words = [dat[vocab][i].strip()[:35] for
i in top] 35 # 數據拼接 36 msg += .join(top_words) 37 print(msg) 38 # 將數據保存到字典裏面 39 topic_to_topwords[j] = top_words 40 41 import warnings 42 warnings.filterwarnings(ignore) 43 prepared_data = pyLDAvis.prepare(dat[topic_term_dists], dat[doc_topic_dists], 44 dat[doc_lengths] * 1.0, dat[vocab], dat[term_frequency] * 1.0, mds=tsne) 45 46 from sklearn.datasets import fetch_20newsgroups 47 remove=(headers, footers, quotes) 48 texts = fetch_20newsgroups(subset=train, remove=remove).data 49 50 51 ############################################## 52 ## 選取一篇文章,確定該文章有哪些主題 53 ############################################## 54 55 print(texts[1]) 56 tt = dat[doc_topic_dists][1] 57 msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}" 58 # 遍歷這20個主題,觀察一下它的權重,權重符合的跳出來 59 for topic_id, weight in enumerate(dat[doc_topic_dists][1]): 60 if weight > 0.01: 61 # 權重符合要求,那麽輸出該主題下的關聯詞匯 62 text = , .join(topic_to_topwords[topic_id]) 63 print (msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text)) 64 65 # plt.bar(np.arange(20), dat[‘doc_topic_dists‘][1]) 66 67 print(texts[51]) 68 tt = texts[51] 69 msg = "{weight:02d}% in topic {topic_id:02d} which has top words {text:s}" 70 for topic_id, weight in enumerate(dat[doc_topic_dists][51]): 71 if weight > 0.01: 72 text = , .join(topic_to_topwords[topic_id]) 73 print(msg.format(topic_id=topic_id, weight=int(weight * 100.0), text=text)) 74 75 76 # plt.bar(np.arange(20), dat[‘doc_topic_dists‘][51])

LDA模型數據的可視化