
On Using Batch Normalization in TensorFlow

When implementing Batch Normalization in TensorFlow (normalizing each network layer's outputs, which stabilizes training and also has a mild regularizing effect), two APIs are the main building blocks:

1)tf.nn.moments(x, axes, name=None, keep_dims=False) ⇒ mean, variance: 

The values it computes are statistical moments: mean is the first moment and variance is the second central moment. The parameters are:

  • x: the data to normalize, e.g. a tensor of shape [batchsize, height, width, kernels]
  • axes: the dimensions to compute the moments over, given as a list, e.g. [0, 1, 2]
  • name: an optional name for the operation
  • keep_dims: whether to keep the reduced dimensions
An example:
img = tf.Variable(tf.random_normal([2, 3]))
axis = list(range(len(img.get_shape()) - 1))  # [0] for a rank-2 tensor
mean, variance = tf.nn.moments(img, axis)
The outputs are:
img = [[ 0.69495416  2.08983064 -1.08764684]
         [ 0.31431156 -0.98923939 -0.34656194]]
mean =  [ 0.50463283  0.55029559 -0.71710438]
variance =  [ 0.0362222   2.37016821  0.13730171]
This example is easy to understand: the function computes a mean and a variance over dimension [0].

2) tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
   tf.nn.batch_norm_with_global_normalization(t, m, v, beta, gamma, variance_epsilon, scale_after_normalization, name=None)

As the interfaces show, the mean and variance returned by tf.nn.moments are passed along as arguments to tf.nn.batch_normalization.
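To make the computation concrete, here is a small NumPy sketch (an illustration, not the author's code) that reproduces what tf.nn.moments did over axis [0] in the example above: the per-column mean and the population (biased) variance.

```python
import numpy as np

# The same 2x3 batch as the tf.nn.moments example above.
img = np.array([[0.69495416, 2.08983064, -1.08764684],
                [0.31431156, -0.98923939, -0.34656194]])

mean = img.mean(axis=0)      # first moment, one value per column
variance = img.var(axis=0)   # second central moment (biased: divides by N, not N-1)

print(mean)      # ≈ [ 0.5046  0.5503 -0.7171]
print(variance)  # ≈ [0.0362 2.3702 0.1373]
```

Up to float32 rounding, these match the mean and variance printed by the TensorFlow example.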

Among these parameters, x, mean, and variance are already known; they come from tf.nn.moments. The other two, offset and scale, generally need to be trained: offset is usually initialized to 0 and scale to 1, and both have the same shape as mean.
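The transform tf.nn.batch_normalization applies is just y = scale * (x - mean) / sqrt(variance + epsilon) + offset. A NumPy sketch of that formula (illustrative values, not TensorFlow's source):

```python
import numpy as np

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon):
    # Normalize to zero mean / unit variance, then apply the trainable
    # scale (gamma) and offset (beta).
    x_hat = (x - mean) / np.sqrt(variance + variance_epsilon)
    return scale * x_hat + offset

x = np.array([[1.0, 2.0], [3.0, 6.0]])
mean, variance = x.mean(axis=0), x.var(axis=0)

# offset initialized to 0 and scale to 1, same shape as mean.
offset, scale = np.zeros_like(mean), np.ones_like(mean)

y = batch_normalization(x, mean, variance, offset, scale, 1e-3)
# With scale=1 and offset=0, each column of y has mean ~0 and variance ~1.
```

Because scale and offset are applied per feature, they must have the same shape as mean, which is why the post initializes them with shape [size].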
def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.99):
    """Assume a 2-d [batch, values] tensor."""
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[1]
        # Trainable parameters: scale starts at 1, offset at 0.
        scale = tf.get_variable('scale', [size], initializer=tf.ones_initializer())
        offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer())

        # Non-trainable moving averages of the statistics, used at inference time.
        pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer(), trainable=False)
        pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer(), trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, [0])
        train_mean_op = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

        def batch_statistics():
            # Training branch: update the moving averages, then normalize
            # with the current batch's statistics.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            # Inference branch: normalize with the accumulated population statistics.
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

        return tf.cond(training, batch_statistics, population_statistics)
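The pop_mean / pop_var updates above are exponential moving averages: each training step moves the population statistics a fraction (1 - decay) toward the current batch statistics. A NumPy sketch of that update rule (illustrative values, not from the post):

```python
import numpy as np

decay = 0.99
pop_mean = np.zeros(3)  # same initialization as the 'pop_mean' variable above

# Suppose, for illustration, every batch produces the same batch_mean.
batch_mean = np.array([0.5, -1.0, 2.0])

for step in range(2000):
    # The update performed by tf.assign(pop_mean, pop_mean*decay + batch_mean*(1-decay)):
    pop_mean = pop_mean * decay + batch_mean * (1 - decay)

# After many steps the moving average converges toward the batch statistics,
# so the inference branch (population_statistics) sees a stable estimate.
```

In a real graph, training would typically be a tf.placeholder of type tf.bool, so tf.cond can switch between the two branches at run time.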
