
Text classification with LSTM using tflearn

Model training code

The script below walks a corpus stored as one sub-directory per class (each file is one document), segments the text with jieba, builds and saves a word-to-id vocabulary while dropping stop words, holds out roughly the first 10% of each class as a test set plus a further 10% of the remaining data for validation, pads every document to 30 word ids, and trains an embedding -> LSTM -> softmax network with tflearn.
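Two helpers from tflearn.data_utils do the final tensor shaping, so it is worth seeing them in isolation first. A minimal standalone sketch, with made-up word ids and labels purely for illustration:

from tflearn.data_utils import to_categorical, pad_sequences

# two documents of different length, encoded as word-id lists
docs = [[4, 8, 15], [16, 23, 42, 7, 9]]
X = pad_sequences(docs, maxlen=4, value=0.)   # zero-pad / truncate to shape (2, 4)
Y = to_categorical([0, 2], nb_classes=3)      # one-hot: [[1,0,0],[0,0,1]]
print(X)
print(Y)

The full training script follows; it uses maxlen=30 and nb_classes equal to the number of class directories.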

# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tflearn
import os
import numpy
import jieba
import re
import fire
from tflearn.data_utils import to_categorical, pad_sequences

def load_data1(keywordPath, stopwords_set, filepath, dictfilepath, n_words,
               valid_portion=0.1, sort_by_len=True):
    jieba.load_userdict(keywordPath)  # domain keywords for the jieba segmenter
    pathDir = os.listdir(filepath)
    data_set = []
    train_set_x = []
    train_set_y = []
    test_set_x = []
    test_set_y = []
    # build the stop-word lookup table
    stopwords = {}
    fstop = open(stopwords_set, 'rb')
    for eachWord in fstop:
        word = eachWord.strip().decode('utf-8', 'ignore')
        stopwords[word] = word
    fstop.close()
    # the vocabulary (word -> id) is written out so later runs can reuse it
    f1 = open(dictfilepath, 'w', encoding='UTF-8')
    dic = dict()
    i = 0  # running word id
    j = 0  # number of documents processed
    # walk one sub-directory per class, building the vocabulary and
    # encoding every document as a list of word ids
    for allDir in pathDir:
        child = filepath + allDir
        if os.path.isdir(child):
            pathSubDir = os.listdir(child)
            k = 1
            for subDir in pathSubDir:
                des = child + os.sep + subDir
                invert = []  # this document as word ids, e.g. [22, 123, 424, ...]
                fOpen = open(des, "r", encoding='UTF-8')
                for eachLine in fOpen:
                    line = eachLine.strip()
                    # strip whitespace and Chinese/ASCII punctuation
                    line1 = re.sub(r"[\s+\.\!\/_,$%^*()?;;:-【】+\"\']+|[+——!,;:。?、~@#¥%……&*()]+", "", line)
                    wordList = list(jieba.cut(line1))
                    for word in wordList:
                        if word not in stopwords:
                            data_set.append(word)
                            if word not in dic:
                                i = i + 1
                                dic[word] = i
                                invert.append(dic[word])
                                if re.match('[^ \t\n\x0B\f\r]', word, flags=0):
                                    f1.write(word + " " + str(i))
                                    f1.write("\n")
                            else:
                                invert.append(dic[word])
                j = j + 1
                # roughly the first 10% of each class goes to the test set;
                # the directory name is used as the class label and is assumed
                # to be an integer id in [0, classnum)
                n = len(pathSubDir)
                if k <= n * 0.1:
                    print(str(j) + " test " + allDir)
                    test_set_x.append(invert)
                    test_set_y.append(int(allDir))
                else:
                    print(str(j) + " train " + allDir)
                    train_set_x.append(invert)
                    train_set_y.append(int(allDir))
                k += 1
                fOpen.close()
    f1.close()
    print("the number of words : " + str(i))
    # split a validation set off the shuffled training data
    n_samples = len(train_set_x)
    sidx = numpy.random.permutation(n_samples)
    n_train = int(numpy.round(n_samples * (1. - valid_portion)))
    valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
    valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
    train_set_x = [train_set_x[s] for s in sidx[:n_train]]
    train_set_y = [train_set_y[s] for s in sidx[:n_train]]

    def remove_unk(x):
        # ids outside the vocabulary are mapped to 1 (the "unknown" id)
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    train_set_x = remove_unk(train_set_x)
    valid_set_x = remove_unk(valid_set_x)
    test_set_x = remove_unk(test_set_x)

    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # sorting by length groups similarly sized documents together
    if sort_by_len:
        sorted_index = len_argsort(test_set_x)
        test_set_x = [test_set_x[i] for i in sorted_index]
        test_set_y = [test_set_y[i] for i in sorted_index]
        sorted_index = len_argsort(valid_set_x)
        valid_set_x = [valid_set_x[i] for i in sorted_index]
        valid_set_y = [valid_set_y[i] for i in sorted_index]
        sorted_index = len_argsort(train_set_x)
        train_set_x = [train_set_x[i] for i in sorted_index]
        train_set_y = [train_set_y[i] for i in sorted_index]
    train = (train_set_x, train_set_y)
    valid = (valid_set_x, valid_set_y)
    test = (test_set_x, test_set_y)
    return train, valid, test


def train(keywordPath, dictPath, modelPath, stopword_setPath, classnum):
    # the five arguments are supplied on the command line via fire
    print("#######################")
    print("#        train        #")
    print("#######################")
    classnum = int(classnum)
    # vocabulary size = number of lines in the dictionary file written by a
    # previous run of load_data1 (which rewrites it)
    words = []
    f = open(dictPath, "r", encoding="utf-8")
    for line in f:
        words.append(line)
    f.close()
    word_num = len(words)
    s = os.sep  # platform path separator
    dataPath = "d:" + s + "data"  # hard-coded corpus root, one sub-directory per class
    train, valid, test = load_data1(keywordPath=keywordPath,
                                    stopwords_set=stopword_setPath,
                                    filepath=dataPath,
                                    dictfilepath=dictPath,
                                    n_words=word_num,
                                    valid_portion=0.1)
    trainX, trainY = train
    valX, valY = valid
    # pad/truncate every document to a fixed length of 30 word ids
    trainX = pad_sequences(trainX, maxlen=30, value=0.)
    valX = pad_sequences(valX, maxlen=30, value=0.)
    # one-hot encode the integer class labels
    trainY = to_categorical(trainY, nb_classes=classnum)
    valY = to_categorical(valY, nb_classes=classnum)
    # network: embedding -> LSTM -> softmax classifier
    net = tflearn.input_data([None, 30])
    net = tflearn.embedding(net, input_dim=word_num, output_dim=128)
    net = tflearn.lstm(net, 128, dropout=0.8)
    net = tflearn.fully_connected(net, classnum, activation='softmax')
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.01,
                             loss='categorical_crossentropy')
    model = tflearn.DNN(net, tensorboard_verbose=0)
    model.fit(trainX, trainY, n_epoch=1, validation_set=(valX, valY),
              show_metric=True, batch_size=256)
    model.save(modelPath)


if __name__ == '__main__':
    fire.Fire(train)
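With fire, the entry point takes its five arguments from the command line, for example: python lstm_train.py keyword.dict vocab.dict model.tfl stopwords.txt 9 (the script name, file names, and class count here are all placeholders). The post stops at training; a hedged sketch of what prediction could look like, assuming the vocabulary file written by load_data1 and integer class labels, is:

# A sketch, not from the original post: reload the saved model and classify
# one new document. All paths are hypothetical; the network definition must
# match train() exactly, since tflearn's model.load only restores weights.
import jieba
import tflearn
from tflearn.data_utils import pad_sequences

def predict_one(text, keywordPath, dictPath, modelPath, classnum):
    jieba.load_userdict(keywordPath)
    # rebuild the word -> id mapping from the "word id" lines of the vocabulary file
    dic = {}
    with open(dictPath, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) == 2:
                dic[parts[0]] = int(parts[1])
    word_num = len(dic)  # assumption: one word per line, ids 1..word_num
    # identical architecture to train()
    net = tflearn.input_data([None, 30])
    net = tflearn.embedding(net, input_dim=word_num, output_dim=128)
    net = tflearn.lstm(net, 128, dropout=0.8)
    net = tflearn.fully_connected(net, classnum, activation='softmax')
    net = tflearn.regression(net, optimizer='adam', learning_rate=0.01,
                             loss='categorical_crossentropy')
    model = tflearn.DNN(net)
    model.load(modelPath)
    # encode the document the same way load_data1 does, then clip unknown ids
    ids = [dic[w] for w in jieba.cut(text) if w in dic]
    ids = [1 if w >= word_num else w for w in ids]
    x = pad_sequences([ids], maxlen=30, value=0.)
    probs = model.predict(x)[0]
    return max(range(classnum), key=lambda c: probs[c])  # predicted class id

Rebuilding the identical graph before model.load is a tflearn requirement: the checkpoint stores only the weights, not the network structure.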