1. 程式人生 > >迴圈神經網路系列(五)Tensorflow中BasicLSTMCell

迴圈神經網路系列(五)Tensorflow中BasicLSTMCell


1.結論

照慣例,先上結論,再說過程,不想看過程的可直接略過。

從這個圖我們可以知道,一個LSTM cell中有4個引數,並且形狀都是一樣的shape=[output_size+n,output_size],其中n表示輸入張量的維度,output_size通過函式BasicLSTMCell(num_units=output_size)獲得。

2.怎麼來的?

讓我們一步一步從Tensorflow的原始碼中來獲得這些資訊!

2.1 cell.state_size

首先,需要明白Tensorflow中,state表示的是cell中有幾個狀態。例如在BasicRNNCell

中,state就只有h這一個狀態;而在BasicLSTMCell中,state就有h和c這兩個狀態。其次,state_size表示的是每個狀態的第二維度,也就是output_size。

舉例:

import tensorflow as tf

output_size = 10
batch_size = 32
dim = 50
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=output_size)
print(cell.state_size)

>>
LSTMStateTuple(c=10, h=10) 

LSTMStateTuple(c=10, h=10)

就表示,c和h的output_size都為10,即[batch_size,10]。另外Tensorflow在實現的時候,都將c,h困在一起了,即以Tuple的方式,這也是Tensorflow所推薦的。

2.2 cell.zero_state

在LSTM中,zero_state就自然對應兩個部分了, h 0 , c

0 h_0,c_0

import tensorflow as tf

output_size = 10
batch_size = 32
dim = 50
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=output_size)
input = tf.placeholder(dtype=tf.float32, shape=[batch_size, 50])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
print(h0)

>>
LSTMStateTuple(c=<tf.Tensor 'BasicLSTMCellZeroState/zeros:0' shape=(32, 10) dtype=float32>,
h=<tf.Tensor 'BasicLSTMCellZeroState/zeros_1:0' shape=(32, 10) dtype=float32>)

可以看到,返回了c,h兩個零狀態!

2.3 關鍵性的一步cell.call

先說明一點,由於原始碼很多,所有在下面的講解中只會列出函式名和對應核心的程式碼,具體可以跳轉至call的實現部分。

import tensorflow as tf

output_size = 10
batch_size = 32
dim = 50
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=output_size)
input = tf.placeholder(dtype=tf.float32, shape=[batch_size, 50])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
new_h, new_state = cell.call(input, h0)


------------------------------------------------------------------


def call(self, inputs, state):
    if self._state_is_tuple:
        c, h = state
    concat = _linear([inputs, h], 4 * self._num_units, True)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    new_c = (
            c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)
    if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
    else:
        new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

首先,從call()中的第17行程式碼可知,首先從state中得到傳進來的c,h,然後就開始進行第18行的操作了,為了弄清楚裡面到底幹了啥,於是我們再次進入_linear()這個函式。

總的來說,_linear()這個函式實現了:先將 h , x h,x 進行concat,然後進行線性變換;如下圖所示:

而與上面兩個步驟對應程式碼就是下面的

concat = _linear([inputs, h], 4 * self._num_units, True)

def _linear(args,
            output_size,
            bias, # 是否要新增偏置
            bias_initializer=None,
            kernel_initializer=None):
            
    shapes = [a.get_shape() for a in args]# 得到傳進來的inputs,h的shape
    total_arg_size += shape[1].value # 得到inputs和h一共有多少列
    res = math_ops.matmul(array_ops.concat(args, 1), weights)# 同時計算得到4個部分的線性對映

其中, w i , w j , w f , w o w_i,w_j,w_f,w_o 就是LSTM cell中的4個權重引數。

現在我們已經清楚了_linear()的作用,接下來我們就再次回到 c a l l ( ) call() 這個函式中:

def call(self, inputs, state):
    if self._state_is_tuple:
        c, h = state
    concat = _linear([inputs, h], 4 * self._num_units, True)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    new_c = (
            c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)
    if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
    else:
        new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

經過上面第4行程式碼,我們得到了4個部分線性變換後的結果;然後根據第6行程式碼可以看出,將concat又分割成了四個部分,而這四個部分對應關係如下圖所示:

接著第7,9行程式碼就分別計算出了new_c,new_h

3. 總結

import tensorflow as tf

output_size = 10
batch_size = 32
dim = 50
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=output_size)
input = tf.placeholder(dtype=tf.float32, shape=[batch_size, 50])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
new_h, new_state = cell.call(input, h0)


對於一個基本的LSTM cell,在定義的時候,通過上面第6行程式碼的num_units,第8行的batch_size和第7行的input我們就得到了cell中所有引數的shape。

對於LSTM按時間維度展開的方法和多層堆疊同BasicRNNCell一致,在此就不贅述,可參見本系列文章的(二)(三)。


迴圈神經網路系列博文

迴圈神經網路系列(一)Tensorflow中BasicRNNCell
迴圈神經網路系列(二)Tensorflow中dynamic_rnn
迴圈神經網路系列(三)Tensorflow中MultiRNNCell
迴圈神經網路系列(四)基於LSTM的MNIST手寫體識別
迴圈神經網路系列(五)Tensorflow中BasicLSTMCell