
Analyzing the RNNCell source code in TensorFlow, and how to define a custom RNNCell

When reproducing papers we often run into models that modify the RNN or LSTM slightly, or that define their own recurrent structure altogether, for example the memory-network papers introduced earlier, which usually require building the recurrent network according to a custom definition. This post therefore summarizes how RNNCell is used and how to define your own RNNCell to fit your needs.

How RNNCell is used in tf

Let's look at the source code directly to see how tf defines RNNCell. The code is as follows:

    class RNNCell(base_layer.Layer):

      def __call__(self, inputs, state, scope=None):
        if scope is not None:
          with vs.variable_scope(scope,
                                 custom_getter=self._rnn_get_variable) as scope:
            return super(RNNCell, self).__call__(inputs, state, scope=scope)
        else:
          with vs.variable_scope(vs.get_variable_scope(),
                                 custom_getter=self._rnn_get_variable):
            return super(RNNCell, self).__call__(inputs, state)

      def _rnn_get_variable(self, getter, *args, **kwargs):
        variable = getter(*args, **kwargs)
        trainable = (variable in tf_variables.trainable_variables() or
                     (isinstance(variable, tf_variables.PartitionedVariable) and
                      list(variable)[0] in tf_variables.trainable_variables()))
        if trainable and variable not in self._trainable_weights:
          self._trainable_weights.append(variable)
        elif not trainable and variable not in self._non_trainable_weights:
          self._non_trainable_weights.append(variable)
        return variable

      @property
      def state_size(self):
        raise NotImplementedError("Abstract method")

      @property
      def output_size(self):
        raise NotImplementedError("Abstract method")

      def build(self, _):
        pass

      def zero_state(self, batch_size, dtype):
        with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
          state_size = self.state_size
          return _zero_state_tensors(state_size, batch_size, dtype)

RNNCell is an abstract parent class; every other cell inherits from it and implements its call() function. From the definition above we can see two main properties, state_size and output_size, which give the dimensionality of the hidden state and of the output respectively. There are also two functions, zero_state() and call(): the former initializes the initial state h0 as an all-zero vector, and the latter defines the actual per-step computation of the cell (a single activation for the vanilla RNN, the two gates of the GRU, the three gates of the LSTM, and so on; the differences between RNN variants live mainly in this function). With this abstract class in hand, let's look at how tf defines BasicRNNCell, GRUCell and BasicLSTMCell, to understand how the different RNN variants are implemented and how they differ.

    class BasicRNNCell(RNNCell):

      def __init__(self, num_units, activation=None, reuse=None):
        super(BasicRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh

      @property
      def state_size(self):
        return self._num_units

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        output = self._activation(_linear([inputs, state], self._num_units, True))
        return output, output

BasicRNNCell implements the simplest recurrent structure. Notice that state_size and output_size are defined to be the same, and that h_t and the output are also identical (the call function returns output twice, i.e. no separate output transformation is defined). The actual work happens in the first line of call: the input and the previous state pass through one linear map followed by one activation, which is exactly the vanilla RNN. In other words, output = new_state = f(W * input + U * state + B). A short usage sketch follows, and after that we look at the GRU definition.
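The sketch below is my own addition (it assumes the TF 1.x API; batch size, sequence length and dimensions are arbitrary example values) and shows how such a cell is typically unrolled over a sequence:

    import tensorflow as tf

    batch_size, seq_len, input_dim, num_units = 32, 10, 128, 64

    # A batch of sequences, one input_dim-dimensional vector per time step.
    inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units)

    # zero_state builds the all-zero h0; dynamic_rnn then unrolls cell.call over time.
    initial_state = cell.zero_state(batch_size, tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

    # outputs: [batch_size, seq_len, num_units]; final_state: [batch_size, num_units].
    # For BasicRNNCell the last output equals the final state, since call returns
    # the same tensor twice.

With that, here is the GRU definition: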

    class GRUCell(RNNCell):

      def __init__(self,
                   num_units,
                   activation=None,
                   reuse=None,
                   kernel_initializer=None,
                   bias_initializer=None):
        super(GRUCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh
        self._kernel_initializer = kernel_initializer
        self._bias_initializer = bias_initializer

      @property
      def state_size(self):
        return self._num_units

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        with vs.variable_scope("gates"):  # Reset gate and update gate.
          # We start with bias of 1.0 to not reset and not update.
          bias_ones = self._bias_initializer
          if self._bias_initializer is None:
            dtype = [a.dtype for a in [inputs, state]][0]
            bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
          value = math_ops.sigmoid(
              _linear([inputs, state], 2 * self._num_units, True, bias_ones,
                      self._kernel_initializer))
          r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
        with vs.variable_scope("candidate"):
          c = self._activation(
              _linear([inputs, r * state], self._num_units, True,
                      self._bias_initializer, self._kernel_initializer))
        new_h = u * state + (1 - u) * c
        return new_h, new_h

Compared with BasicRNNCell, only the call function changes. It adds a reset gate and an update gate, denoted r and u (both computed with a sigmoid), and c is the candidate state computed with the activation (tanh by default). The corresponding equations are:

    r = sigmoid(W1 * input + U1 * state + B1)

    u = sigmoid(W2 * input + U2 * state + B2)

    c = tanh(W3 * input + U3 * (r * state) + B3)

    h_new = u * state + (1 - u) * c
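To make these equations concrete, here is a plain NumPy sketch of a single GRU step (my own illustration, not the TF code; the implementation above fuses the r and u computations into one _linear call instead of keeping separate W/U matrices):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_step(x, h, W1, U1, b1, W2, U2, b2, W3, U3, b3):
        r = sigmoid(x @ W1 + h @ U1 + b1)          # reset gate
        u = sigmoid(x @ W2 + h @ U2 + b2)          # update gate
        c = np.tanh(x @ W3 + (r * h) @ U3 + b3)    # candidate state
        return u * h + (1.0 - u) * c               # new hidden state h_new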

Next let's look at BasicLSTMCell. Compared with the GRU, the LSTM adds an output gate as well as an extra internal state c; h and c are returned together as a tuple (LSTMStateTuple) that forms the LSTM's internal state. The implementation looks like this:

    class BasicLSTMCell(RNNCell):

      def __init__(self, num_units, forget_bias=1.0,
                   state_is_tuple=True, activation=None, reuse=None):

        super(BasicLSTMCell, self).__init__(_reuse=reuse)
        if not state_is_tuple:
          logging.warn("%s: Using a concatenated state is slower and will soon be "
                       "deprecated.  Use state_is_tuple=True.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh

      @property
      def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
          c, h = state
        else:
          c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        concat = _linear([inputs, h], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

        new_c = (
            c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
        else:
          new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state

From the code above we can see that, once again, the differences from BasicRNNCell and GRUCell are essentially confined to the call() function: that is where each variant implements its own behaviour. You will also have noticed a function that all three classes use heavily, _linear(). What does it do? Let's look at its source first:

    def _linear(args,
                output_size,
                bias,
                bias_initializer=None,
                kernel_initializer=None):

      if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
      if not nest.is_sequence(args):
        args = [args]

      # Calculate the total size of arguments on dimension 1.
      total_arg_size = 0
      shapes = [a.get_shape() for a in args]
      for shape in shapes:
        if shape.ndims != 2:
          raise ValueError("linear is expecting 2D arguments: %s" % shapes)
        if shape[1].value is None:
          raise ValueError("linear expects shape[1] to be provided for shape %s, "
                           "but saw %s" % (shape, shape[1]))
        else:
          total_arg_size += shape[1].value

      dtype = [a.dtype for a in args][0]

      # Now the computation.
      scope = vs.get_variable_scope()
      with vs.variable_scope(scope) as outer_scope:
        weights = vs.get_variable(
            _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
            dtype=dtype,
            initializer=kernel_initializer)
        if len(args) == 1:
          res = math_ops.matmul(args[0], weights)
        else:
          res = math_ops.matmul(array_ops.concat(args, 1), weights)
        if not bias:
          return res
        with vs.variable_scope(outer_scope) as inner_scope:
          inner_scope.set_partitioner(None)
          if bias_initializer is None:
            bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
          biases = vs.get_variable(
              _BIAS_VARIABLE_NAME, [output_size],
              dtype=dtype,
              initializer=bias_initializer)
        return nn_ops.bias_add(res, biases)

The args argument of this function is [input, state], and output_size is the size of the output layer. In BasicRNNCell output_size is _num_units, in GRUCell it is 2 * _num_units, and in BasicLSTMCell it is 4 * _num_units. The reason is that _linear implements the W x + U h + B term that appears in the RNN equations, but different variants need a different number of such terms: the LSTM, for example, needs four, so output_size is simply set to 4 * _num_units and the result is split back into four tensors afterwards.
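The same "one big matmul, then split into gates" pattern can also be written out directly, which may make the 4 * _num_units trick clearer (a sketch only, with illustrative shapes and the TF 1.x API assumed):

    import tensorflow as tf

    batch_size, input_dim, num_units = 32, 128, 64
    x = tf.placeholder(tf.float32, [batch_size, input_dim])
    h = tf.placeholder(tf.float32, [batch_size, num_units])

    # One weight matrix covering [input, state] and all four LSTM pre-activations.
    W = tf.get_variable("lstm_kernel", [input_dim + num_units, 4 * num_units])
    b = tf.get_variable("lstm_bias", [4 * num_units], initializer=tf.zeros_initializer())

    # A single matmul does the work of four separate W x + U h + B computations ...
    concat = tf.nn.xw_plus_b(tf.concat([x, h], 1), W, b)
    # ... and the result is split into input gate, new input, forget gate and output gate.
    i, j, f, o = tf.split(concat, num_or_size_splits=4, axis=1)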

That concludes our quick walk through how the different RNN cells are implemented in tensorflow. Next, let's see how to implement the RNNCell that our own model needs.

How to define a custom RNNCell in tf

Recurrent Entity Networks

Having seen how the GRU and LSTM cells are implemented, it should not be hard to guess how to define your own RNNCell: inherit from the abstract class RNNCell and implement the four members __init__, state_size, output_size and call, putting whatever computation you need inside call. Let's illustrate this with the code used in Recurrent Entity Networks, a paper we reproduced earlier. In that model each cell holds m slots, i.e. m memories, each of which is a mem_sz-dimensional vector, and each slot also has a key used to address it. The update each slot performs can be written as follows.
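(The notation here is reconstructed from the call function below rather than copied from the paper: s_t is the input sentence vector, h_j the j-th memory, w_j its key, and U, V, W the shared parameter matrices.)

    g_j  = sigmoid(s_t · h_j + s_t · w_j)      # gate (Eq. 1 in the comments below)
    h~_j = activation(U h_j + V w_j + W s_t)   # candidate memory (Eq. 2 in the comments below)
    h_j  = h_j + g_j * h~_j                    # gated update
    h_j  = h_j / ||h_j||                       # normalization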

    class DynamicMemory(tf.contrib.rnn.RNNCell):
        def __init__(self, memory_slots, memory_size, keys, activation=prelu,
                     initializer=tf.random_normal_initializer(stddev=0.1)):
            """
            Instantiate a DynamicMemory Cell, with the given number of memory slots, and key vectors.
            :param memory_slots: Number of memory slots to initialize. 
            :param memory_size: Dimensionality of memories => tied to embedding size. 
            :param keys: List of keys to seed the Dynamic Memory with (can be random).
            :param initializer: Variable Initializer for Cell Parameters.
            """ 
            self.m, self.mem_sz, self.keys = memory_slots, memory_size, keys
            self.activation, self.init = activation, initializer

            # The three parameter matrices of Eq. 2, shared by all the RNN cells (time steps).
            self.U = tf.get_variable("U", [self.mem_sz, self.mem_sz], initializer=self.init)
            self.V = tf.get_variable("V", [self.mem_sz, self.mem_sz], initializer=self.init)
            self.W = tf.get_variable("W", [self.mem_sz, self.mem_sz], initializer=self.init)

        @property
        def state_size(self):
            return [self.mem_sz for _ in range(self.m)]

        @property
        def output_size(self):
            return [self.mem_sz for _ in range(self.m)]

        def zero_state(self, batch_size, dtype):
            return [tf.tile(tf.expand_dims(key, 0), [batch_size, 1]) for key in self.keys]

        def __call__(self, inputs, state, scope=None):
            """
            Run the Dynamic Memory Cell on the inputs, updating the memories with each new time step.
            :param inputs: 2D Tensor of shape [bsz, mem_sz] representing a story sentence.
            :param states: List of length M, each with 2D Tensor [bsz, mem_sz] => h_j (starts as key).
            """
            new_states = []
            # The loop below goes over the m memory slots, applying the same update to each one
            for block_id, h in enumerate(state):
                # The next three lines implement Eq. 1, i.e. the computation of the gate g
                content_g = tf.reduce_sum(tf.multiply(inputs, h), axis=[1])                  # Shape: [bsz]
                address_g = tf.reduce_sum(tf.multiply(inputs, 
                                          tf.expand_dims(self.keys[block_id], 0)), axis=[1]) # Shape: [bsz]
                g = sigmoid(content_g + address_g)

                # The next four lines implement Eq. 2: compute the candidate memory h_ from the input s and the memory h
                h_component = tf.matmul(h, self.U)                                           # Shape: [bsz, mem_sz]
                w_component = tf.matmul(tf.expand_dims(self.keys[block_id], 0), self.V)      # Shape: [1, mem_sz]
                s_component = tf.matmul(inputs, self.W)                                      # Shape: [bsz, mem_sz]
                candidate = self.activation(h_component + w_component + s_component)         # Shape: [bsz, mem_sz]

                # Add the candidate memory h_, scaled by the gate g, to the original memory h
                new_h = h + tf.multiply(tf.expand_dims(g, -1), candidate)                    # Shape: [bsz, mem_sz]

                # Normalize the memory h
                new_h_norm = tf.nn.l2_normalize(new_h, -1)                                   # Shape: [bsz, mem_sz]
                new_states.append(new_h_norm)

            return new_states, new_states

A cell defined in this way can be handed directly to tf.nn.dynamic_rnn() for unrolling when building the model.
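For example, a minimal unrolling sketch might look like the following (my own usage example rather than code from the repo; the shapes, the number of slots and the way the keys are created are illustrative assumptions):

    import tensorflow as tf

    batch_size, seq_len, mem_sz, num_slots = 32, 10, 100, 20

    # One embedded sentence per time step.
    story = tf.placeholder(tf.float32, [batch_size, seq_len, mem_sz])
    keys = [tf.get_variable("key_%d" % j, [mem_sz]) for j in range(num_slots)]

    cell = DynamicMemory(num_slots, mem_sz, keys, activation=tf.nn.relu)

    # zero_state seeds every memory slot with its key; dynamic_rnn handles the
    # list-valued state because state_size/output_size are declared as lists.
    memories, final_memories = tf.nn.dynamic_rnn(
        cell, story, initial_state=cell.zero_state(batch_size, tf.float32))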

Neural Turing Machines

Besides this, we can also define a cell entirely by ourselves without inheriting from RNNCell. Looking at the official description of RNNCell below, all that is really required is a call function:


    Every RNNCell must have the properties below and implement call with the signature
    (output, next_state) = call(input, state). The optional third input argument, scope,
    is allowed for backwards compatibility purposes; but should be left off for new subclasses.

Sometimes we need even more flexibility. In that case we can skip inheriting from RNNCell and simply define an ordinary class; the trade-off is that we can no longer always rely on tf.nn.dynamic_rnn to build the model automatically, and instead have to write our own loop that calls the cell repeatedly to get the unrolling effect. Let's look at the NTM code as an example.

    cell = ntm_cell.NTMCell(args.rnn_size, args.memory_size, args.memory_vector_dim, 1, 1,
                            addressing_mode='content_and_location',
                            reuse=reuse,
                            output_dim=args.vector_dim)

    state = cell.zero_state(args.batch_size, tf.float32)
    self.state_list = [state]
    for t in range(seq_length):
        output, state = cell(tf.concat([self.x[:, t, :], np.zeros([args.batch_size, 1])], axis=1), state)
        self.state_list.append(state)
    output, state = cell(eof, state)
    self.state_list.append(state)

The lines above first create an NTMCell object, then build the initial state with zero_state, and then call the cell in a loop, saving the intermediate states as they go. The NTMCell itself is defined as shown below: it does not inherit from RNNCell; everything is written by hand.

    class NTMCell():
        def __init__(self, rnn_size, memory_size, memory_vector_dim, read_head_num, write_head_num,
                     addressing_mode='content_and_location', shift_range=1, reuse=False, output_dim=None):
            self.rnn_size = rnn_size
            self.memory_size = memory_size
            self.memory_vector_dim = memory_vector_dim
            self.read_head_num = read_head_num
            self.write_head_num = write_head_num
            self.addressing_mode = addressing_mode
            self.reuse = reuse
            self.controller = tf.nn.rnn_cell.BasicRNNCell(self.rnn_size)
            self.step = 0
            self.output_dim = output_dim
            self.shift_range = shift_range

        def __call__(self, x, prev_state):
            prev_read_vector_list = prev_state['read_vector_list']      # read vector in Sec 3.1 (the content that is
                                                                        # read out, length = memory_vector_dim)
            prev_controller_state = prev_state['controller_state']      # state of controller (LSTM hidden state)

            # x + prev_read_vector -> controller (RNN) -> controller_output
            controller_input = tf.concat([x] + prev_read_vector_list, axis=1)
            with tf.variable_scope('controller', reuse=self.reuse):
                controller_output, controller_state = self.controller(controller_input, prev_controller_state)

            num_parameters_per_head = self.memory_vector_dim + 1 + 1 + (self.shift_range * 2 + 1) + 1
            num_heads = self.read_head_num + self.write_head_num
            total_parameter_num = num_parameters_per_head * num_heads + self.memory_vector_dim * 2 * self.write_head_num
            with tf.variable_scope("o2p", reuse=(self.step > 0) or self.reuse):
                o2p_w = tf.get_variable('o2p_w', [controller_output.get_shape()[1], total_parameter_num],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
                o2p_b = tf.get_variable('o2p_b', [total_parameter_num],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
                parameters = tf.nn.xw_plus_b(controller_output, o2p_w, o2p_b)
            head_parameter_list = tf.split(parameters[:, :num_parameters_per_head * num_heads], num_heads, axis=1)
            erase_add_list = tf.split(parameters[:, num_parameters_per_head * num_heads:], 2 * self.write_head_num, axis=1)

            # k, beta, g, s, gamma -> w

            prev_w_list = prev_state['w_list']  # vector of weightings (blurred address) over locations
            prev_M = prev_state['M']
            w_list = []
            p_list = []
            for i, head_parameter in enumerate(head_parameter_list):

                # Some functions to constrain the result in specific range
                # exp(x)                -> x > 0
                # sigmoid(x)            -> x \in (0, 1)
                # softmax(x)            -> sum_i x_i = 1
                # log(exp(x) + 1) + 1   -> x > 1

                k = tf.tanh(head_parameter[:, 0:self.memory_vector_dim])
                beta = tf.sigmoid(head_parameter[:, self.memory_vector_dim]) * 10        # do not use exp, it will explode!
                g = tf.sigmoid(head_parameter[:, self.memory_vector_dim + 1])
                s = tf.nn.softmax(
                    head_parameter[:, self.memory_vector_dim + 2:self.memory_vector_dim + 2 + (self.shift_range * 2 + 1)]
                )
                gamma = tf.log(tf.exp(head_parameter[:, -1]) + 1) + 1
                with tf.variable_scope('addressing_head_%d' % i):
                    w = self.addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i])     # Figure 2
                w_list.append(w)
                p_list.append({'k': k, 'beta': beta, 'g': g, 's': s, 'gamma': gamma})

            # Reading (Sec 3.1)

            read_w_list = w_list[:self.read_head_num]
            read_vector_list = []
            for i in range(self.read_head_num):
                read_vector = tf.reduce_sum(tf.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1)
                read_vector_list.append(read_vector)

            # Writing (Sec 3.2)

            write_w_list = w_list[self.read_head_num:]
            M = prev_M
            for i in range(self.write_head_num):
                w = tf.expand_dims(write_w_list[i], axis=2)
                erase_vector = tf.expand_dims(tf.sigmoid(erase_add_list[i * 2]), axis=1)
                add_vector = tf.expand_dims(tf.tanh(erase_add_list[i * 2 + 1]), axis=1)
                M = M * (tf.ones(M.get_shape()) - tf.matmul(w, erase_vector)) + tf.matmul(w, add_vector)

            # controller_output -> NTM output

            if not self.output_dim:
                output_dim = x.get_shape()[1]
            else:
                output_dim = self.output_dim
            with tf.variable_scope("o2o", reuse=(self.step > 0) or self.reuse):
                o2o_w = tf.get_variable('o2o_w', [controller_output.get_shape()[1], output_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
                o2o_b = tf.get_variable('o2o_b', [output_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
                NTM_output = tf.nn.xw_plus_b(controller_output, o2o_w, o2o_b)

            state = {
                'controller_state': controller_state,
                'read_vector_list': read_vector_list,
                'w_list': w_list,
                'p_list': p_list,
                'M': M
            }

            self.step += 1
            return NTM_output, state

        def addressing(self, k, beta, g, s, gamma, prev_M, prev_w):

            # Sec 3.3.1 Focusing by Content

            # Cosine Similarity

            k = tf.expand_dims(k, axis=2)
            inner_product = tf.matmul(prev_M, k)
            k_norm = tf.sqrt(tf.reduce_sum(tf.square(k), axis=1, keep_dims=True))
            M_norm = tf.sqrt(tf.reduce_sum(tf.square(prev_M), axis=2, keep_dims=True))
            norm_product = M_norm * k_norm
            K = tf.squeeze(inner_product / (norm_product + 1e-8))                   # eq (6)

            # Calculating w^c

            K_amplified = tf.exp(tf.expand_dims(beta, axis=1) * K)
            w_c = K_amplified / tf.reduce_sum(K_amplified, axis=1, keep_dims=True)  # eq (5)

            if self.addressing_mode == 'content':                                   # Only focus on content
                return w_c

            # Sec 3.3.2 Focusing by Location

            g = tf.expand_dims(g, axis=1)
            w_g = g * w_c + (1 - g) * prev_w                                        # eq (7)

            s = tf.concat([s[:, :self.shift_range + 1],
                           tf.zeros([s.get_shape()[0], self.memory_size - (self.shift_range * 2 + 1)]),
                           s[:, -self.shift_range:]], axis=1)
            t = tf.concat([tf.reverse(s, axis=[1]), tf.reverse(s, axis=[1])], axis=1)
            s_matrix = tf.stack(
                [t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1] for i in range(self.memory_size)],
                axis=1
            )
            w_ = tf.reduce_sum(tf.expand_dims(w_g, axis=1) * s_matrix, axis=2)      # eq (8)
            w_sharpen = tf.pow(w_, tf.expand_dims(gamma, axis=1))
            w = w_sharpen / tf.reduce_sum(w_sharpen, axis=1, keep_dims=True)        # eq (9)

            return w

        def zero_state(self, batch_size, dtype):
            def expand(x, dim, N):
                return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)

            with tf.variable_scope('init', reuse=self.reuse):
                state = {
                    'controller_state': expand(tf.tanh(tf.get_variable('init_state', self.rnn_size,  initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))), dim=0, N=batch_size),
                    'read_vector_list': [expand(tf.nn.softmax(tf.get_variable('init_r_%d' % i, [self.memory_vector_dim], initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))), dim=0, N=batch_size) for i in range(self.read_head_num)],
                    'w_list': [expand(tf.nn.softmax(tf.get_variable('init_w_%d' % i, [self.memory_size], initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))), dim=0, N=batch_size) if self.addressing_mode == 'content_and_location' else tf.zeros([batch_size, self.memory_size]) for i in range(self.read_head_num + self.write_head_num)],
                    'M': expand(tf.tanh(tf.get_variable('init_M', [self.memory_size, self.memory_vector_dim], initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))), dim=0, N=batch_size)
                }
                return state

With these two examples we have now covered the two ways of defining a custom RNNCell in tensorflow. I hope this helps when you build your own models with tf.