1. 程式人生 > >tf.train.XXX與train有關的函式

tf.train.XXX與train有關的函式

tf.train.XXX與train有關的函式


tf.train.get_or_create_global_step()

  • 這個函式主要用於返回或者建立(如果有必要的話)一個全域性步數的tensor。引數只有一個,就是圖,如果沒有指定那麼就是預設的圖。

tf.trainable_variables()

  • 返回所有trainable=True的變數。
  • 當我們在宣告變數Variable()時傳入trainable=TrueVariable()建構函式會自動新增新的變數到圖中的集合GraphKeys.TRAINABLE_VARIABLES
    ,這個函式實質上就是返回這個集合中的變數。

tensorflow.python.training.moving_averages.assign_moving_average

這個函式的引數如下:

def assign_moving_average(variable, value, decay, zero_debias=True, name=None):

對於variable的滑動平均更新為:

v a r i a b l e = v a r i
a b l e d e c a y + v a l u e ( 1 d e c a y )

下面是一個簡單的例子(可以看出variable是變數,而value是常量),這個函式主要應用於batch_normalization

def testAssignMovingAverage(self):
  with self.test_session():
    var = tf.Variable([10.0, 11.0])
    val = tf.constant([1.0, 2.0], tf.float32)
    decay = 0.25
    assign = moving_averages.assign_moving_average(var, val, decay)
    tf.global_variables_initializer().run()
    self.assertAllClose([10.0, 11.0], var.eval())
    assign.op.run()
    self.assertAllClose([10.0 * 0.25 + 1.0 * (1.0 - 0.25),
                         11.0 * 0.25 + 2.0 * (1.0 - 0.25)],
                        var.eval())