
Recurrent Neural Network Series (2): dynamic_rnn in TensorFlow

1. Review

In the previous post (Recurrent Neural Network Series (1): BasicRNNCell in TensorFlow), we covered how a single RNN cell is implemented in TensorFlow and what each of its parameters means. With that, we can compute one step of an RNN cell in TensorFlow:
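Recall that one step of a BasicRNNCell computes $output = h_1 = \tanh(W x + U h_0 + b)$; the output and the new hidden state are the same tensor, which is why the two arrays printed below are identical.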

import tensorflow as tf
import numpy as np

x = np.array([[1, 0, 1, 2], [2, 1, 1, 1]])
X = tf.placeholder(dtype=tf.float32, shape=[2, 4], name='input')
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=5)  # output_size: 5; could also be a GRUCell, LSTMCell, etc.
h0 = cell.zero_state(batch_size=2, dtype=tf.float32)  # batch_size: 2
output, h1 = cell.call(X, h0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([output, h1], feed_dict={X: x})
    print('output:')
    print(a)
    print('h1:')
    print(b)

>> output:
[[ 0.4495004   0.9573416   0.6013933   0.75571895 -0.8172958 ]
 [ 0.6624889   0.7011481   0.68771356  0.77796507 -0.7617092 ]]
h1:
[[ 0.4495004   0.9573416   0.6013933   0.75571895 -0.8172958 ]
 [ 0.6624889   0.7011481   0.68771356  0.77796507 -0.7617092 ]]

The code above performs the following operation:

[Figure: one step of an RNN cell, mapping (h_0, x) to (output, h_1)]

In practice, however, we usually want to do something like this:

[Figure: an RNN cell unrolled over three time steps]

We feed in $h_0, x_1$ to get $output_1, h_1$; then feed in $h_1, x_2$ to get $output_2, h_2$; then feed in $h_2, x_3$ to get $output_3, h_3$, and so on (a manual version of this loop is sketched below). So how can we do all of this in one step with TensorFlow?
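As a minimal sketch (not from the original post; the batch size of 2 and input dimension of 4 are illustrative), stepping the cell by hand in TF 1.x would look like this:

import tensorflow as tf

# Hypothetical shapes: batch of 2, input dimension 4, three time steps.
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=5)
xs = [tf.placeholder(tf.float32, shape=[2, 4], name='x%d' % (t + 1)) for t in range(3)]
state = cell.zero_state(batch_size=2, dtype=tf.float32)  # h0
outputs = []
for x_t in xs:
    output, state = cell(x_t, state)  # the same weights are reused at every step
    outputs.append(output)
# outputs -> [output1, output2, output3]; state -> h3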

2. dynamic_rnn

To compute several steps in a single call, we use dynamic_rnn() in TensorFlow. The following code implements the three steps shown in the figure above:

import tensorflow as tf
import numpy as np

output_size = 5  # num_units of the RNN cell
batch_size = 4
time_step = 3    # number of time steps to compute
dim = 3          # input feature dimension
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])  # time-major input
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
a, b = sess.run([outputs, final_state], feed_dict={inputs:X})
print(a)
print(b)

Here, setting time_step = 3 means three steps are computed, so the input X correspondingly has three slices. In the final result, outputs contains $(outputs_1, outputs_2, outputs_3)$, while final_state is just $h_3$; moreover, $outputs_3$ and $h_3$ are equal (verified in the snippet after the printed results).

Result:

outputs:
[[[ 0.9427065  -0.92617476 -0.79179853  0.6308035   0.07298201]
  [ 0.7051633  -0.62077284 -0.79618317  0.5004738  -0.20110159]
  [ 0.85066974 -0.77197933 -0.76875883  0.80251306 -0.04951192]
  [ 0.67497337 -0.57974416 -0.4408107   0.68083197  0.05233984]]# output1

 [[ 0.9828192  -0.9433205  -0.9233751   0.72930676 -0.34445292]
  [ 0.92153275 -0.58029604 -0.8949743   0.5431045  -0.46945637]
  [ 0.9690989  -0.7922626  -0.8973758   0.81312704 -0.46288016]
  [ 0.88565385 -0.6617377  -0.68075943  0.70066273 -0.34827012]]# output2

 [[ 0.99172366 -0.93298715 -0.9272905   0.7158564  -0.46278387]
  [ 0.9566409  -0.5595625  -0.9101479   0.58005375 -0.5905321 ]
  [ 0.9838727  -0.7693646  -0.91019756  0.82892674 -0.58026373]
  [ 0.9438508  -0.61732507 -0.7356022   0.73460865 -0.483655  ]]]# output3
final_state:
[[ 0.99172366 -0.93298715 -0.9272905   0.7158564  -0.46278387]
 [ 0.9566409  -0.5595625  -0.9101479   0.58005375 -0.5905321 ]
 [ 0.9838727  -0.7693646  -0.91019756  0.82892674 -0.58026373]
 [ 0.9438508  -0.61732507 -0.7356022   0.73460865 -0.483655  ]]# final_state
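A quick check of the claim above, reusing a and b from the run:

print(np.allclose(a[-1], b))  # True: with time_major=True, a[-1] is outputs_3, which equals h_3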

3. Summary

When using dynamic_rnn, the input data can take one of two formats:

First format: the input has shape [batch_size, time_steps, input_size]; the resulting output then has shape [batch_size, time_steps, output_size], and final_state has shape [batch_size, state_size].

Second format: the one used above, where the input has shape [time_steps, batch_size, input_size]; the output then has shape [time_steps, batch_size, output_size], while final_state still has shape [batch_size, state_size]. This format requires passing time_major=True.

Comparing the two, the main advantage of the second format is that the outputs are grouped by time step, which makes the results easier to read. A sketch of the first, batch-major format follows.
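For completeness, here is a minimal sketch of the first (batch-major) format with the same sizes as above; the input X here is dummy data, not the values from the example:

import tensorflow as tf
import numpy as np

output_size, batch_size, time_step, dim = 5, 4, 3, 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
# Batch-major input: [batch_size, time_steps, input_size]
inputs = tf.placeholder(dtype=tf.float32, shape=[batch_size, time_step, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
# time_major defaults to False, so no extra flag is needed here
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0)

X = np.ones([batch_size, time_step, dim], dtype=np.float32)  # dummy data
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([outputs, final_state], feed_dict={inputs: X})
    print(a.shape)  # (4, 3, 5) = [batch_size, time_steps, output_size]
    print(b.shape)  # (4, 5)    = [batch_size, state_size]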