1. 程式人生 > >Tensorflow RNN原始碼解析筆記1:RNNCell的基本實現

Tensorflow RNN原始碼解析筆記1:RNNCell的基本實現

前言

本系列主要主要是記錄下Tensorflow在RNN實現這一塊的相關程式碼,不做詳細解釋,主要是翻譯加筆記。

RNNCell

在Tensorflow中,定義了一個RNNCell的抽象類,具體的所有不同型別的RNN Cell都是基於這個類的,所以就首先講一下這個,下面是基本的程式碼:

class RNNCell(object):
  def __call__(self, inputs, state, scope=None):
    raise NotImplementedError("Abstract method")

  @property
  def state_size
(self):
raise NotImplementedError("Abstract method") @property def output_size(self): raise NotImplementedError("Abstract method") def zero_state(self, batch_size, dtype): state_size = self.state_size if nest.is_sequence(state_size): state_size_flat = nest.flatten(state_size) zeros_flat = [ array_ops.zeros( array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])), dtype=dtype) for
s in state_size_flat] for s, z in zip(state_size_flat, zeros_flat): z.set_shape(_state_size_with_prefix(s, prefix=[None])) zeros = nest.pack_sequence_as(structure=state_size, flat_sequence=zeros_flat) else: zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size]) zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype) zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None
])) return zeros

在Tensorflow中,Cell的定義不同於其他資料當中的定義,在其他的文件中Cell(下文指代為L-Cell)被看做是一個能夠產生Single Scalar輸出的物件,而在這裡則是一個包含一系列L-Cell的水平陣列。

具體到RNNCell,RNNCell是一個包含一個State(狀態)並且能夠執行一些處理輸入矩陣的物件。RNNCell將輸入的矩陣(Input Matrix),運算輸出一個包含”self.output”列的輸出矩陣(Ouput Matrix)。如果定義了“self.state_size”這個屬性,並且取值為一個整數,那麼RNNCell則會同時輸出一個狀態矩陣(State Matrix),包含“self.state_size”列。而如果“self.state_size”定義為一個整數的Tuple,,那麼則是輸出對應長度的狀態矩陣的Tuple,Tuple中的每一個狀態矩陣長度還是和之前的一樣,包含“self.state_size”列。

在Tensorflow中,將會基於整個RNNCell實現一系列常用的RNNCell,比如LSTM和GRU,並且將會支援包含Dropout等在內的特性,同時也支援構建多層的RNN網路。

RNNCell基本結構

RNNCell有一些基本的屬性需要設定:

state_size: 說明這個Cell使用的State的大小
output_size: 這個RNNCell最後生成結果的大小

對於每一個RNNCell的具體實現類,都必須要實現__call__這個方法:

每一個具體的RNN類必須實現的方法:
def __call__(self, inputs, state, scope=None):

這個方法是RNNCell的核心方法,其需要的屬性有:

inputs: 這個需要輸入一個形狀為[batch_size,input_size]的2D Tensor,batch_size是你訓練中指定的batch的大小,而input_size則是輸入資料的維度

state: state就是你rnn網路中rnn cell的狀態,比如說如果你的rnn定義包含了N個單元(也就是你的self.state_size是個整數N),那麼在你每次執行RNN網路時就應該給一個[batch_size,self.state_size]形狀的2D Tensor來表示當前RNN網路的狀態,而如果你的self.state_size是一個元祖,那麼給定的狀態也應該是一個Tuple,每個Tuple裡的狀態表示和之前的方式一樣,只要注意好不同的self.state_size取值就好

而RNN Cell經過一系列的工作後,將會輸出如下的東西:

output:對應的你的batch的大小和輸出大小的結果,形狀是[batch_size x self.output_size]

state:根據你的self.state_size的不同,輸出一個更新後的RNN狀態,或者一個Tuple的狀態,格式對應輸入的state

同時RNNCell還定義了一個非抽象的方法,那就是生成初始化狀態的方法,比較簡單就不多說了:

 def zero_state(self, batch_size, dtype):

BasicRNNCell

這裡寫圖片描述
下面介紹完了RNNCell的定義,我們來看一個最原始的RNN的實現,就是不涉及到LSTM,GRU的那種。這種RNNCell被稱作BasicRNNCell,程式碼很簡短:

class BasicRNNCell(RNNCell):
  """The most basic RNN cell."""

  def __init__(self, num_units, input_size=None, activation=tanh):
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
    with vs.variable_scope(scope or type(self).__name__):  # "BasicRNNCell"
      output = self._activation(_linear([inputs, state], self._num_units, True))
    return output, output

在最基本的RNN實現當中,RNN在時間t的輸出,就是其在時間t的狀態

output = new_state = activation(W * input + U * state + B)

這個計算就直接在__call__中計算完成了,這個函式比較簡單,但是他具體如何計算則呼叫了一個方法,不在類中,那麼我們看看這個函式先:

_linear([inputs, state], self._num_units, True)

對應函式介紹,_liner的功能就是你給了一個或一系列的Tensor(A,B,C.....),他給你計算一個W1*A+W2*B.....+Bias的結果存在,比如輸入[input,state],那麼這個方法就是計算W * input + U * state:
def _linear(args, output_size, bias, bias_start=0.0, scope=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
    scope: VariableScope for the created subgraph; defaults to "Linear".

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.

到此,關於Tensorflow裡面RNNCell的基本結構,以及BasicRNNCell的原始碼分析結束。
以上,MebiuW