1. 程式人生 > >CNN字元級中文文字分類-基於TensorFlow實現

CNN字元級中文文字分類-基於TensorFlow實現

本章旨在使用TensorFlow API實現卷積神經網路文字分類。

如今,TensorFlow大版本已經升級到了1.3,對很多的網路層實現了更高層次的封裝和實現,甚至還整合瞭如Keras這樣優秀的一些高層次框架,使得其易用性大大提升。相比早起的底層程式碼,如今的實現更加簡潔和優雅。

本章的目的是基於TensorFlow的API來重新實現一個在中文文字上的分類器。如果你覺得對你有些許幫助或者疑惑,歡迎star和交流。

資料集

本文采用了清華NLP組提供的THUCNews新聞文字分類資料集的一個子集(原始的資料集大約74萬篇文件,訓練起來需要花較長的時間)。資料集請自行到

THUCTC:一個高效的中文文字分類工具包下載,請遵循資料提供方的開源協議。

本次訓練使用了其中的10個分類,每個分類6500條,總共65000條新聞資料。

類別如下:

體育, 財經, 房產, 家居, 教育, 科技, 時尚, 時政, 遊戲, 娛樂

資料集劃分如下:

  • 訓練集: 5000*10
  • 驗證集: 500*10
  • 測試集: 1000*10

從原資料集生成子集的過程請參看helper下的兩個指令碼。其中,copy_data.sh用於從每個分類拷貝6500個檔案,cnews_group.py用於將多個檔案整合到一個檔案中。執行該檔案後,得到三個資料檔案:

  • cnews.train.txt: 訓練集(50000條)
  • cnews.val.txt: 驗證集(5000條)
  • cnews.test.txt: 測試集(10000條)

預處理

data/cnews_loader.py為資料的預處理檔案。

  • read_file(): 讀取檔案資料;
  • build_vocab(): 構建詞彙表,使用字元級的表示,這一函式會將詞彙表儲存下來,避免每一次重複處理;
  • read_vocab(): 讀取上一步儲存的詞彙表,轉換為{詞:id}表示;
  • read_category(): 將分類目錄固定,轉換為{類別: id}表示;
  • to_words(): 將一條由id表示的資料重新轉換為文字;
  • preocess_file(): 將資料集從文字轉換為固定長度的id序列表示;
  • batch_iter()
    : 為神經網路的訓練準備經過shuffle的批次的資料。

經過資料預處理,資料的格式如下:

Data Shape Data Shape
x_train [50000, 600] y_train [50000, 10]
x_val [5000, 600] y_val [5000, 10]
x_test [10000, 600] y_test [10000, 10]

CNN卷積神經網路

配置項

CNN可配置的引數如下所示,在cnn_model.py中。

class TCNNConfig(object):
    """CNN配置引數"""

    embedding_dim = 64      # 詞向量維度
    seq_length = 600        # 序列長度
    num_classes = 10        # 類別數
    num_filters = 128        # 卷積核數目
    kernel_size = 5         # 卷積核尺寸
    vocab_size = 5000       # 詞彙表達小

    hidden_dim = 128        # 全連線層神經元

    dropout_keep_prob = 0.5 # dropout保留比例
    learning_rate = 1e-3    # 學習率

    batch_size = 64         # 每批訓練大小
    num_epochs = 10         # 總迭代輪次

    print_per_batch = 100    # 每多少輪輸出一次結果
    save_per_batch = 10      # 每多少輪存入tensorboard

CNN模型

具體參看cnn_model.py的實現。

大致結構如下:

images_cnn_architecture

訓練與驗證

執行 python run_cnn.py train,可以開始訓練。

若之前進行過訓練,請把tensorboard/textcnn刪除,避免TensorBoard多次訓練結果重疊。

Configuring CNN model...
Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:14
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:  10.94%, Val Loss:    2.3, Val Acc:   8.92%, Time: 0:00:01 *
Iter:    100, Train Loss:   0.88, Train Acc:  73.44%, Val Loss:    1.2, Val Acc:  68.46%, Time: 0:00:04 *
Iter:    200, Train Loss:   0.38, Train Acc:  92.19%, Val Loss:   0.75, Val Acc:  77.32%, Time: 0:00:07 *
Iter:    300, Train Loss:   0.22, Train Acc:  92.19%, Val Loss:   0.46, Val Acc:  87.08%, Time: 0:00:09 *
Iter:    400, Train Loss:   0.24, Train Acc:  90.62%, Val Loss:    0.4, Val Acc:  88.62%, Time: 0:00:12 *
Iter:    500, Train Loss:   0.16, Train Acc:  96.88%, Val Loss:   0.36, Val Acc:  90.38%, Time: 0:00:15 *
Iter:    600, Train Loss:  0.084, Train Acc:  96.88%, Val Loss:   0.35, Val Acc:  91.36%, Time: 0:00:17 *
Iter:    700, Train Loss:   0.21, Train Acc:  93.75%, Val Loss:   0.26, Val Acc:  92.58%, Time: 0:00:20 *
Epoch: 2
Iter:    800, Train Loss:   0.07, Train Acc:  98.44%, Val Loss:   0.24, Val Acc:  94.12%, Time: 0:00:23 *
Iter:    900, Train Loss:  0.092, Train Acc:  96.88%, Val Loss:   0.27, Val Acc:  92.86%, Time: 0:00:25
Iter:   1000, Train Loss:   0.17, Train Acc:  95.31%, Val Loss:   0.28, Val Acc:  92.82%, Time: 0:00:28
Iter:   1100, Train Loss:    0.2, Train Acc:  93.75%, Val Loss:   0.23, Val Acc:  93.26%, Time: 0:00:31
Iter:   1200, Train Loss:  0.081, Train Acc:  98.44%, Val Loss:   0.25, Val Acc:  92.96%, Time: 0:00:33
Iter:   1300, Train Loss:  0.052, Train Acc: 100.00%, Val Loss:   0.24, Val Acc:  93.58%, Time: 0:00:36
Iter:   1400, Train Loss:    0.1, Train Acc:  95.31%, Val Loss:   0.22, Val Acc:  94.12%, Time: 0:00:39
Iter:   1500, Train Loss:   0.12, Train Acc:  98.44%, Val Loss:   0.23, Val Acc:  93.58%, Time: 0:00:41
Epoch: 3
Iter:   1600, Train Loss:    0.1, Train Acc:  96.88%, Val Loss:   0.26, Val Acc:  92.34%, Time: 0:00:44
Iter:   1700, Train Loss:  0.018, Train Acc: 100.00%, Val Loss:   0.22, Val Acc:  93.46%, Time: 0:00:47
Iter:   1800, Train Loss:  0.036, Train Acc: 100.00%, Val Loss:   0.28, Val Acc:  92.72%, Time: 0:00:50
No optimization for a long time, auto-stopping...

在驗證集上的最佳效果為94.12%,且只經過了3輪迭代就已經停止。

準確率和誤差如圖所示:

images

測試

執行 python run_cnn.py test 在測試集上進行測試。

Configuring CNN model...
Loading test data...
Testing...
Test Loss:   0.14, Test Acc:  96.04%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         體育       0.99      0.99      0.99      1000
         財經       0.96      0.99      0.97      1000
         房產       1.00      1.00      1.00      1000
         家居       0.95      0.91      0.93      1000
         教育       0.95      0.89      0.92      1000
         科技       0.94      0.97      0.95      1000
         時尚       0.95      0.97      0.96      1000
         時政       0.94      0.94      0.94      1000
         遊戲       0.97      0.96      0.97      1000
         娛樂       0.95      0.98      0.97      1000

avg / total       0.96      0.96      0.96     10000

Confusion Matrix...
[[991   0   0   0   2   1   0   4   1   1]
 [  0 992   0   0   2   1   0   5   0   0]
 [  0   1 996   0   1   1   0   0   0   1]
 [  0  14   0 912   7  15   9  29   3  11]
 [  2   9   0  12 892  22  18  21  10  14]
 [  0   0   0  10   1 968   4   3  12   2]
 [  1   0   0   9   4   4 971   0   2   9]
 [  1  16   0   4  18  12   1 941   1   6]
 [  2   4   1   5   4   5  10   1 962   6]
 [  1   0   1   6   4   3   5   0   1 979]]
Time usage: 0:00:05

在測試集上的準確率達到了96.04%,且各類的precision, recall和f1-score都超過了0.9。

從混淆矩陣也可以看出分類效果非常優秀。