1. 程式人生 > >迴圈神經網路系列(六)基於LSTM的唐詩生成

迴圈神經網路系列(六)基於LSTM的唐詩生成


1. 思路

這個示例在很多地方都出現過,對於學習理解LSTM的原理極有幫助,因此我們下面就來一步一步地弄清楚其中的奧祕所在!

對於迴圈神經網路來說,我們首先需要做的仍舊是找到一種將資料序列化的方法。當然,對於古詩詞來說,每個字的出現順序就是天然的一個序列,因此我們就可以直接按照這個序列來處理。並且一首古詩詞可以看成是一個樣本(為了敘述方便,我們下面僅以一首詩的第一句為例子),例如[[床前明月光],[小時不識月]]為兩個樣本。

1.1 網路訓練模型

現在暫時假設我們的訓練集中只有兩個樣本x=[[床前明月光],[小時不識月]],那麼想想此時對應的標籤應該是什麼?回想一下,我們的目的是利用迴圈網路來寫詩,也就是說當我們把模型訓練好了之後,直接餵給模型第一個字,它就能寫出一句(或一首)詩了;既然如此,那麼我們的訓練過程就應該是學習每首古詩中,所有字的一個出現順序。所以,對於床前明月光

這個樣本來說,其對應的標籤值就應該是前明月光光。由此可知,我們的網路模型就應該大致長這樣:

接下來,為了能更清晰的敘述網路結構而不被其它因素影響,我們在這個小節中就直接用一個漢字來表示一個維度(實際中要將一個漢字轉為n維的詞向量)。此時,訓練樣本的維度就應該是shape=[2,5,1]。在這個示例中,我們採用了兩層的LSTM網路外加一個softmax的全連線層,並且LSTM網路的輸出維度output_size=32,於是我們就可以畫出下面這個網路示意圖:

從圖中可以看到,第一步:我們是將shape = [2,5,1]的訓練集餵給LSTM網路,然後從網路得到shape=[2,5,32]

的輸出;第二步:我們將LSTM網路得到的輸出reshape成[10,32]的矩陣;第三步:再將上一步的結果餵給最後一個softmax全連線網路,這樣就能完成對於每個字的分類任務了。

對於第二步為什麼要reshape然後餵給第三部的全連線網路,我們可以這樣想:假如是一個樣本的話,那麼LSTM的輸出大小就為[1,5,32],也就是說第一步喂進去的每一個字通過LSTM這個網路處理之後都變成一個[1,32]的向量化表示方式,只是第2個字保留了第一個字裡面的資訊,第3個字保留了跟前面的資訊等等。這也就有點類似於卷積網路中先用卷積層對圖片進行特徵提取,然後再做一個分類處理。於是乎我們就可以發現,其實LSTM網路的本質也是在做一個特徵提取的工作,區別於卷積網路的就是:卷積網路提取的是基於空間上的特徵,而迴圈網路提取的是基於時間序列上的特徵。至於最後以層,該分類就分類處理,該回歸就回歸處理。

1.2 網路預測模型

當網路經過訓練完成後就可以拿來預測了,只不過在預測的時候我們餵給網路的就只是一個字了;然後用當前預測得到的字作為下一個字;如下圖所示:

1.3 資料處理

經過上面的講解,我們大致明白了基於LSTM網路古詩生成原理:先用LSTM做特徵提取,然後分類。既然最後我們要完成的是一個分類任務,那麼我們不得不做的就是將所有的類別給整理處理,也就是所有的資料集中一共包含了多少個不同的字,因為我們來做的就是根據上一個字預測下一個字。

同時由於我們處理的是文字資訊,因此我們需要將每個字都採用詞(字)向量的形式表示,由於沒有現成的詞向量,所有我們要再LSTM的前面假加入一個詞嵌入層。

最後,為了避免最終的分類數過於龐大,可以選擇去掉出現頻率較小的字,比如可以去掉只出現過一次的字。

總結一下資料預處理的步驟:

  • 1.統計出所有不同的字,並做成一個字典;
  • 2.對於每首詩,將每個字、標點都轉換為字典中對應的編號,構成X;
  • 3.將X整體左移動以為構成Y

2. 程式碼講解

在此首先感謝Github上的jinfaganghzy46這兩位作者,因為整體程式碼都是參照的他們的,加了一點點自己的元素。

2.1 資料預處理

先來看看原始的資料集長什麼樣:

首春:寒隨窮律變,春逐鳥聲開。初風飄帶柳,晚雪間花梅。碧林青舊竹,綠沼翠新苔。芝田初雁去,綺樹巧鶯來。
初晴落景:晚霞聊自怡,初晴彌可喜。日晃百花色,風動千林翠。池魚躍不同,園鳥聲還異。寄言博通者,知予物外志。

而我們需要得到的是類似於這樣的:

X:
[[1,4,6,3,2,5,3,0,0,0],
 [5,6,4,3,9,1,0,0,0,0]]
 
Y:
[[4,6,3,2,5,3,0,0,0,0],
 [6,4,3,9,1,0,0,0,0,0]]

其中的0表示,我設定了一首詩的最大長度,如果不足就補0(因為每首詩的長度不一樣);而其它的數字則表示詩中每個字以及標點在字典中的索引。同時,為了後面的生成詩時候的轉換,我們還需要得到字典。

而這隻需要tensorflow中的幾行程式碼就能搞定(友情提示:在統計詞頻使用Counter()這個類時,對於同一詞頻的詞在字典中的排列順序window平臺和linux平臺的處理結果不一樣)。以下只是部分程式碼,完整參見原始碼中的data_helper.py模組

    vocab_processor = VocabularyProcessor(max_document_length=max_length,min_frequency=5)
    x = np.array(list(vocab_processor.fit_transform(poems)))
    dictionary = vocab_processor.vocabulary_.__dict__.copy()
    fre = dictionary['_freq']
    # print(sorted(fre.items(), key=lambda x: x[1], reverse=True))
    word_to_int = dictionary['_mapping']
    int_to_word = dictionary['_reverse_mapping']
    np.random.seed(50)
    shuffle_index = np.random.permutation(x.shape[0])
    shuffle_x = x[shuffle_index]
    shuffle_y = np.copy(shuffle_x)
    shuffle_y[:, :-1] = shuffle_x[:, 1:]

2.2 網路構建

在整個網路構建中,主要分成了四個部分build_input(),build_rnn(),ttrain(),compose_poem()。下面就挑重點的說。

2.2.1 build_input()

由2.1節可知,我們預處理後得到資料的形式是二維的,所以在定義placeholder也要是二維的;同時,由於要採用詞向量進行表示,所以此處還要加入一個詞嵌入層。程式碼如下:


with tf.name_scope('model_inputs'):
    self.inputs = tf.placeholder(dtype=tf.int32, shape=[self.batch_size, None], name='input-x')
    self.targets = tf.placeholder(dtype=tf.int64, shape=[self.batch_size, None], name='input-y')
with tf.name_scope('embedding_layer'):
    self.embedding = tf.Variable(tf.truncated_normal(shape=[self.num_class, self.embedding_size], stddev=0.1),name='embedding')
    self.model_inputs = tf.nn.embedding_lookup(self.embedding,
                                               self.inputs)

由於我們訓練時inputs的第二個維度為詩的長度,預測時為1,所以就寫成了None

2.2.1 build_rnn()

with tf.name_scope('build_rnn_model'):
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [get_a_cell(self.rnn_size) for _ in range(self.num_layer)])  # 搭建num_layer層的模型
    self.initial_state = cell.zero_state(batch_size=self.batch_size, dtype=tf.float32)
    self.outputs, self.final_state = tf.nn.dynamic_rnn(cell, inputs=self.model_inputs,
                                                       initial_state=self.initial_state)
    output = tf.reshape(self.outputs, [-1, self.rnn_size])

第7行程式碼就是圖(id:p0033)中的第二步。接下來就是一個全連線:

with tf.name_scope('full_connection'):
    weights = tf.Variable(tf.truncated_normal(shape=[self.rnn_size, self.num_class]),
                          name='weights')  # [128,5000]
    bias = tf.Variable(tf.zeros(shape=[self.num_class]), name='bias')
    self.logits = tf.nn.xw_plus_b(output, weights, bias, name='logits')

構造損失:

with tf.name_scope('loss'):
    labels = tf.reshape(self.targets, [-1])
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=self.logits)
    self.loss = tf.reduce_mean(loss)

預測值和準確率

with tf.name_scope('accuracy'):
    self.proba_prediction = tf.nn.softmax(self.logits, name='output_probability')
    self.prediction = tf.argmax(self.proba_prediction, axis=1, name='output_prediction')
    correct_predictions = tf.equal(self.prediction, labels)
    self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

原始碼地址