1. 程式人生 > >從一句程式碼看tf.scan

從一句程式碼看tf.scan

在讀這篇文章的時候遇到了以下程式碼:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = cell.zero_state(batch_size, tf.float32)

rnn_outputs, final_states = \
       tf.scan(lambda a, x: cell(x, a[1]),
               tf.transpose(rnn_inputs, [1,0,2]),
               initializer=(tf.zeros([batch_size, state_size]), init_state))

這裡來解釋一下:

首先,tf.scan 第一個輸入是函式,也就是:

tf.scan(lambda a, x: cell(x, a[1])

等價於(未驗證,僅作illustration用):

def func(a, x):
    return cell(x, a[1])

x是輸入,a是上一步函式func的輸出。為什麼輸入cell的是a[1]呢?這是因為,根據官方文件,MultiRNNCell的輸出是:

Returns:

A pair containing:

Output: A 2-D tensor with shape [batch_size, self.output_size].
New state: Either a single 2-D tensor, or a tuple of tensors matching the arity and shapes of state.

換句話說,就是: (output,New state),也就是a。

那麼:

a = (output, new_state)
a[0] = output
a[1] = new_state

所以,cell的輸入,其一是x,也就是每一個time step的輸入,其二是a[1],也就是上一個time step 輸出的hidden state。

然後,tf.scan 的第二個輸入是input,這個沒什麼好說的,需要注意資料的形狀要從[batch_size,num_steps, state_size] 調整為[num_steps, batch_size, state_size]。tf.scan 會一步一步的把input輸入cell,每次的形狀是:[batch_size, state_size]

tf.scan第三個引數是a的初始化,那麼水到渠成,它分別初始化了output和new_state:

initializer=(tf.zeros([batch_size, state_size]), init_state))

至此這句程式碼就分析完畢了。有不懂的同學還請細細鑽研,弄懂了就不難。另附例子如下

def testScan_SingleInputMultiOutput(self):
  with self.test_session() as sess:
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = (np.array(1.0), np.array(-1.0))
    r = tf.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, initializer)
    r_value = sess.run(r)
 
    self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
    self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])