1. 程式人生 > >【tensorflow】tensorflow中的全域性變數GLOBAL_VARIABLES及區域性變數LOCAL_VARIABLES

【tensorflow】tensorflow中的全域性變數GLOBAL_VARIABLES及區域性變數LOCAL_VARIABLES

在初學tensorflow的時候,我們會發現在函式體內定義tf.variable()或者tf.get_variable()變數的時候,跟其他語言不同,在tensorflow的函式體內定義的變數並不會隨著函式的執行結束而消失。這是因為tensorflow設定的全域性變數及區域性變數與其他語言有著本質的不同,這是因為tf裡面是由圖定義的,內部的變數區分為

tf.GraphKeys.GLOBAL_VARIABLES               #=> 'variables'                                                                                                                                                                                 
tf.GraphKeys.LOCAL_VARIABLES #=> 'local_variables' tf.GraphKeys.MODEL_VARIABLES #=> 'model_variables'
tf.GraphKeys.TRAINABLE_VARIABLES #=> 'trainable_variables'



本文主要分析GLOBAL_VARIABLES與LOCAL_VARIABLES之間的區別:

  • 對於區域性變數來說,其變數也是一種普通的變數,不過其定義在tf.GraphKeys.LOCAL_VARIABLES。通常該集合用於儲存程式用於初始化的預設變數列表,因此local指定的變數在預設情況下不會儲存。即不會儲存到checkpoint中
  • 定義一個區域性變數:需要顯示指定所在的變數集合collection
    e = tf.Variable(
    6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])





此處舉一個例子來表明TF的全域性變數及區域性變數:

import tensorflow as tf

def some_func():
    z = tf.Variable(1, name='var_z')

a = tf.Variable(1, name='var_a')
b = tf.get_variable('var_b', 2)
with tf.name_scope('aaa'):
    c = tf.Variable(3, name='var_c')

with tf.variable_scope('bbb'):
    d = tf.Variable(3, name='var_d')

some_func()
some_func()

print [str(i.name) for i in tf.global_variables()]
print [str(i.name) for i in tf.local_variables()]

結果:

['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0']
[]