
Recurrent Neural Network Series (3): MultiRNNCell in TensorFlow

Recurrent Neural Network Series (1): BasicRNNCell in TensorFlow
Recurrent Neural Network Series (2): dynamic_rnn in TensorFlow

In the previous two posts we covered how to define an RNN cell and how to unroll it along the time axis with dynamic_rnn. This post shows how to stack multiple RNN cells into layers (e.g., a multi-layer LSTM) and unroll the whole stack over time. What does a stacked RNN look like when unrolled? As usual, a diagram is the most intuitive way to show it:

Here A and B denote two RNN cells. Each is unrolled over time_step=3, giving a two-layer network with two final states and three outputs. How do we implement this in TensorFlow?
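Before turning to the API, it helps to see what MultiRNNCell actually does at each time step: it simply chains the cells, feeding layer A's output into layer B. Below is a minimal hand-rolled sketch of a single step (assuming TF 1.x; the explicit variable scopes are only there to keep the two cells' weights separate):

import tensorflow as tf

cell_a = tf.nn.rnn_cell.BasicRNNCell(num_units=5)
cell_b = tf.nn.rnn_cell.BasicRNNCell(num_units=5)

x_t = tf.placeholder(tf.float32, shape=[4, 3])  # input at one time step
state_a = cell_a.zero_state(batch_size=4, dtype=tf.float32)
state_b = cell_b.zero_state(batch_size=4, dtype=tf.float32)

# One step of the two-layer stack, done by hand:
with tf.variable_scope("layer_a"):
    out_a, new_state_a = cell_a(x_t, state_a)    # layer A consumes the raw input
with tf.variable_scope("layer_b"):
    out_b, new_state_b = cell_b(out_a, state_b)  # layer B consumes A's output

MultiRNNCell performs this chaining internally at every step, scoping each layer's variables as cell_0, cell_1, and so on.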

1. First, define the two RNN cells

import tensorflow as tf

def get_a_cell(output_size):
    # Each layer is a plain BasicRNNCell with `output_size` hidden units.
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)

output_size = 5  # hidden size of each layer
batch_size = 4
time_step = 3
dim = 3          # feature dimension of the input at each time step
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])

With the code above, we have defined two RNN cells A and B stacked on top of each other.
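As a quick sanity check (assuming TF 1.x), the stacked cell reports one state size per layer, while its output size is that of the top layer:

# state_size is a tuple with one entry per layer; output_size is the top layer's.
print(cell.state_size)   # (5, 5)
print(cell.output_size)  # 5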

2. Unroll it with dynamic_rnn

import tensorflow as tf


def get_a_cell(output_size):
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)


output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])

inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)
print(final_state)

>>
Tensor("rnn/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4, 5), dtype=float32)
(<tf.Tensor 'rnn/while/Exit_2:0' shape=(4, 5) dtype=float32>, <tf.Tensor 'rnn/while/Exit_3:0' shape=(4, 5) dtype=float32>)

From the second printed result we can see that final_state contains two tensors, each of shape=(4, 5), which matches our expectation. But what does the shape=(3, 4, 5) of outputs mean? The 3 here is not a feature dimension: because time_major=True, the first axis is time, so outputs holds 3 per-step results, each of shape=(4, 5), which is again what we expect. Moreover, layer B's final_state should be equal to the third output.
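Since everything above is still graph construction, we can let TensorFlow check that claim for us. A minimal sketch, building on the outputs and final_state tensors just defined (the comparison node would be evaluated later with sess.run):

# With time_major=True, outputs[-1] is the top layer's output at the last step;
# final_state[-1] is the final state of the top layer (layer B).
same = tf.reduce_all(tf.equal(outputs[-1], final_state[-1]))  # evaluates to True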

3. Feed in an example and run it

import tensorflow as tf
import numpy as np


def get_a_cell(output_size):
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)


output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(output_size) for _ in range(2)])

inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)
print(final_state)

X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
sess = tf.Session()
sess.run(tf.global_variables_initializer())
a, b = sess.run([outputs, final_state], feed_dict={inputs: X})
print('outputs:')
print(a)
print('final_state:')
print(b)

>>
Tensor("rnn/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4, 5), dtype=float32)
(<tf.Tensor 'rnn/while/Exit_2:0' shape=(4, 5) dtype=float32>, <tf.Tensor 'rnn/while/Exit_3:0' shape=(4, 5) dtype=float32>)


outputs:
[[[-0.6958626  -0.6776572   0.15731043 -0.6311886   0.20267256]
  [ 0.07732188  0.09182965 -0.49770945  0.0051106   0.23445603]
  [-0.304461   -0.2706095  -0.4083268  -0.3364025   0.26729658]
  [-0.38100582 -0.35050285 -0.2153194  -0.3686508   0.21973696]]

 [[-0.38028494 -0.39984316  0.5924934  -0.7433707   0.45858386]
  [ 0.15477817  0.06120307 -0.23038468 -0.2532196   0.19319542]
  [-0.09605556 -0.23243633  0.18608333 -0.6444844   0.34893066]
  [-0.15772797 -0.2529126   0.32016686 -0.6125384   0.33331177]]

 [[-0.45718285 -0.20688602  0.66812176 -0.81284994 -0.03955056]
  [ 0.16529301  0.2245452  -0.45850635 -0.36383444  0.18540041]
  [-0.0918629   0.11388774  0.01027385 -0.7402484   0.06189062]
  [-0.21528585  0.00840321  0.20390712 -0.71303254  0.04809263]]]
final_state:
(array([[ 0.01885682,  0.79334605, -0.99330646, -0.19715786,  0.8772415 ],
       [-0.43402836, -0.2537776 , -0.52755517,  0.5360404 , -0.38291538],
       [-0.49418357,  0.28655267, -0.91146743,  0.4856847 ,  0.22705963],
       [-0.3087254 ,  0.42241457, -0.8743213 ,  0.26078507,  0.3464944 ]],
      dtype=float32), 
array([[-0.45718285, -0.20688602,  0.66812176, -0.81284994, -0.03955056],
       [ 0.16529301,  0.2245452 , -0.45850635, -0.36383444,  0.18540041],
       [-0.0918629 ,  0.11388774,  0.01027385, -0.7402484 ,  0.06189062],
       [-0.21528585,  0.00840321,  0.20390712, -0.71303254,  0.04809263]],
      dtype=float32))

We can see that outputs has 3 parts and final_state has 2, and that the third part of outputs is identical to the second part of final_state, confirming our guess above.
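We can also confirm this equality directly in NumPy on the fetched values (a and b from the run above); b is a tuple of per-layer states, so b[-1] is layer B's final state:

import numpy as np

# a has shape (3, 4, 5); a[-1] is the output at the last time step.
# b is a tuple (state_A, state_B); b[-1] is layer B's final state.
print(np.allclose(a[-1], b[-1]))  # True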

Note:

If you want each layer to have a different output size, simply pass a different size for each layer when constructing the MultiRNNCell:

output_size = [5, 6]
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(size) for size in output_size])
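For example, with sizes [5, 6] the top layer now produces 6-dimensional outputs, so dynamic_rnn returns outputs of shape (3, 4, 6) (only the top layer's outputs are returned), while the final states have shapes (4, 5) and (4, 6), one per layer. A quick sketch reusing the same placeholders as above (assuming TF 1.x) to confirm:

import tensorflow as tf

def get_a_cell(output_size):
    return tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)

batch_size, time_step, dim = 4, 3, 3
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(size) for size in [5, 6]])

inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
print(outputs)      # shape=(3, 4, 6): only the top layer's outputs are returned
print(final_state)  # two states: shape=(4, 5) for layer A, shape=(4, 6) for layer B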