【tensorflow】tensorflow中的全域性變數GLOBAL_VARIABLES及區域性變數LOCAL_VARIABLES
阿新 • • 發佈:2018-12-06
在初學tensorflow的時候,我們會發現在函式體內定義tf.variable()或者tf.get_variable()變數的時候,跟其他語言不同,在tensorflow的函式體內定義的變數並不會隨著函式的執行結束而消失。這是因為tensorflow設定的全域性變數及區域性變數與其他語言有著本質的不同,這是因為tf裡面是由圖定義的,內部的變數區分為
tf.GraphKeys.GLOBAL_VARIABLES #=> 'variables'
tf.GraphKeys.LOCAL_VARIABLES #=> 'local_variables'
tf.GraphKeys.MODEL_VARIABLES #=> 'model_variables'
tf.GraphKeys.TRAINABLE_VARIABLES #=> 'trainable_variables'
本文主要分析GLOBAL_VARIABLES與LOCAL_VARIABLES之間的區別:
- 對於區域性變數來說,其變數也是一種普通的變數,不過其定義在tf.GraphKeys.LOCAL_VARIABLES。通常該集合用於儲存程式用於初始化的預設變數列表,因此local指定的變數在預設情況下不會儲存。即不會儲存到checkpoint中。
- 定義一個區域性變數:需要顯示指定所在的變數集合collection
e = tf.Variable(
此處舉一個例子來表明TF的全域性變數及區域性變數:
import tensorflow as tf
def some_func():
z = tf.Variable(1, name='var_z')
a = tf.Variable(1, name='var_a')
b = tf.get_variable('var_b', 2)
with tf.name_scope('aaa'):
c = tf.Variable(3, name='var_c')
with tf.variable_scope('bbb'):
d = tf.Variable(3, name='var_d')
some_func()
some_func()
print [str(i.name) for i in tf.global_variables()]
print [str(i.name) for i in tf.local_variables()]
結果:
['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0']
[]