Tensorflow RNN原始碼解析筆記1:RNNCell的基本實現
前言
本系列主要主要是記錄下Tensorflow在RNN實現這一塊的相關程式碼,不做詳細解釋,主要是翻譯加筆記。
RNNCell
在Tensorflow中,定義了一個RNNCell的抽象類,具體的所有不同型別的RNN Cell都是基於這個類的,所以就首先講一下這個,下面是基本的程式碼:
class RNNCell(object):
def __call__(self, inputs, state, scope=None):
raise NotImplementedError("Abstract method")
@property
def state_size (self):
raise NotImplementedError("Abstract method")
@property
def output_size(self):
raise NotImplementedError("Abstract method")
def zero_state(self, batch_size, dtype):
state_size = self.state_size
if nest.is_sequence(state_size):
state_size_flat = nest.flatten(state_size)
zeros_flat = [
array_ops.zeros(
array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
dtype=dtype)
for s in state_size_flat]
for s, z in zip(state_size_flat, zeros_flat):
z.set_shape(_state_size_with_prefix(s, prefix=[None]))
zeros = nest.pack_sequence_as(structure=state_size,
flat_sequence=zeros_flat)
else:
zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None ]))
return zeros
在Tensorflow中,Cell的定義不同於其他資料當中的定義,在其他的文件中Cell(下文指代為L-Cell)被看做是一個能夠產生Single Scalar輸出的物件,而在這裡則是一個包含一系列L-Cell的水平陣列。
具體到RNNCell,RNNCell是一個包含一個State(狀態)並且能夠執行一些處理輸入矩陣的物件。RNNCell將輸入的矩陣(Input Matrix),運算輸出一個包含”self.output”列的輸出矩陣(Ouput Matrix)。如果定義了“self.state_size”這個屬性,並且取值為一個整數,那麼RNNCell則會同時輸出一個狀態矩陣(State Matrix),包含“self.state_size”列。而如果“self.state_size”定義為一個整數的Tuple,,那麼則是輸出對應長度的狀態矩陣的Tuple,Tuple中的每一個狀態矩陣長度還是和之前的一樣,包含“self.state_size”列。
在Tensorflow中,將會基於整個RNNCell實現一系列常用的RNNCell,比如LSTM和GRU,並且將會支援包含Dropout等在內的特性,同時也支援構建多層的RNN網路。
RNNCell基本結構
RNNCell有一些基本的屬性需要設定:
state_size: 說明這個Cell使用的State的大小
output_size: 這個RNNCell最後生成結果的大小
對於每一個RNNCell的具體實現類,都必須要實現__call__這個方法:
每一個具體的RNN類必須實現的方法:
def __call__(self, inputs, state, scope=None):
這個方法是RNNCell的核心方法,其需要的屬性有:
inputs: 這個需要輸入一個形狀為[batch_size,input_size]的2D Tensor,batch_size是你訓練中指定的batch的大小,而input_size則是輸入資料的維度
state: state就是你rnn網路中rnn cell的狀態,比如說如果你的rnn定義包含了N個單元(也就是你的self.state_size是個整數N),那麼在你每次執行RNN網路時就應該給一個[batch_size,self.state_size]形狀的2D Tensor來表示當前RNN網路的狀態,而如果你的self.state_size是一個元祖,那麼給定的狀態也應該是一個Tuple,每個Tuple裡的狀態表示和之前的方式一樣,只要注意好不同的self.state_size取值就好
而RNN Cell經過一系列的工作後,將會輸出如下的東西:
output:對應的你的batch的大小和輸出大小的結果,形狀是[batch_size x self.output_size]
state:根據你的self.state_size的不同,輸出一個更新後的RNN狀態,或者一個Tuple的狀態,格式對應輸入的state
同時RNNCell還定義了一個非抽象的方法,那就是生成初始化狀態的方法,比較簡單就不多說了:
def zero_state(self, batch_size, dtype):
BasicRNNCell
下面介紹完了RNNCell的定義,我們來看一個最原始的RNN的實現,就是不涉及到LSTM,GRU的那種。這種RNNCell被稱作BasicRNNCell,程式碼很簡短:
class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""
def __init__(self, num_units, input_size=None, activation=tanh):
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._activation = activation
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
"""Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
output = self._activation(_linear([inputs, state], self._num_units, True))
return output, output
在最基本的RNN實現當中,RNN在時間t的輸出,就是其在時間t的狀態
output = new_state = activation(W * input + U * state + B)
這個計算就直接在__call__中計算完成了,這個函式比較簡單,但是他具體如何計算則呼叫了一個方法,不在類中,那麼我們看看這個函式先:
_linear([inputs, state], self._num_units, True)
對應函式介紹,_liner的功能就是你給了一個或一系列的Tensor(A,B,C.....),他給你計算一個W1*A+W2*B.....+Bias的結果存在,比如輸入[input,state],那麼這個方法就是計算W * input + U * state:
def _linear(args, output_size, bias, bias_start=0.0, scope=None):
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
Args:
args: a 2D Tensor or a list of 2D, batch x n, Tensors.
output_size: int, second dimension of W[i].
bias: boolean, whether to add a bias term or not.
bias_start: starting value to initialize the bias; 0 by default.
scope: VariableScope for the created subgraph; defaults to "Linear".
Returns:
A 2D Tensor with shape [batch x output_size] equal to
sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
Raises:
ValueError: if some of the arguments has unspecified or wrong shape.
到此,關於Tensorflow裡面RNNCell的基本結構,以及BasicRNNCell的原始碼分析結束。
以上,MebiuW