1. 程式人生 > >Batch normalization及其在tensorflow中的實現

Batch normalization及其在tensorflow中的實現

Batch normalization(BN)

BN是對輸入的特徵圖進行標準化的操作,其公式為:

 

  • xx - 原輸入
  • x^x^ - 標準化後的輸入
  • μμ - 一個batch中的均值
  • σ2σ2 - 一個batch中的方差
  • ϵϵ - 一個很小的數,防止除0
  • ββ - 中心偏移量(center)
  • γγ - 縮放(scale)係數

tensorflow中提供了三種BN方法:

  • tf.nn.batch_normalization
  • tf.layers.batch_normalization
  • tf.contrib.layers.batch_norm

tf.layers.batch_normalization為例介紹裡面所包含的主要引數:

tf.layers.batch_normalization(inputs, decay=0.999, center=True, scale=True, is_training=True, epsilon=0.001)
  • 1

一般使用只要定義以下的引數即可:

  • inputs: 輸入張量[N, H, W, C]

  • decay: 滑動平均的衰減係數,一般取接近1的值,這樣能在驗證和測試集上獲得較好結果

  • center: 中心偏移量,上述的ββ ,為True,則自動新增,否則忽略

  • scale: 縮放係數,上述的γγ,為True,則自動新增,否則忽略

  • epsilon: 為防止除0而加的一個很小的數

  • is_training: 是否是訓練過程,為True則代表是訓練過程,那麼將根據decay用指數滑動平均求得moments,並累加儲存到moving_meanmoving_variance中。否則是測試過程,函式直接取這兩個引數來用。

    如果是True,則需在訓練的session中新增將BN引數更新操作加入訓練的程式碼:

    
    # execute update_ops to update batch_norm weights
    
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
     optimizer = tf.train.AdamOptimizer(decayed_learning_rate)
     train_op = optimizer.minimize(loss, global_step = global_step)
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

Note

需要看上述函式的詳細引數,可在python終端通過以下命令獲取:

import tensorflow as tf
help(tf.layers.batch_normalization) # help中新增函式名

參考:https://blog.csdn.net/Leo_Xu06/article/details/79054326