1. 程式人生 > >【TensorFlow學習筆記】4:認識Variable及其重用(共享),在scope上的初始化

【TensorFlow學習筆記】4:認識Variable及其重用(共享),在scope上的初始化

學習《深度學習之TensorFlow》時的一些實踐。


認識TF中的Variable

TF通過name來標識變數(Variable),這和呼叫者定義的程式裡的"變數名"無關。當不指定name時,由TF自己指定,當建立的變數的name已經存在時,TF會為其改名。

變數的建立和name指定

# 兩個未命名的變數,TF會自動給名字
a = tf.Variable(1.0)
print("a:", a.name)
b = tf.Variable(2.0)
print("b:", b.name)
# 兩個name一樣的變數,TF會為第二個改名字
c = tf.Variable(
3.0, name='var') print("c:", c.name) d = tf.Variable(4.0, name='var') print("d:", d.name)

a: Variable:0
b: Variable_1:0
c: var:0
d: var_1:0

讀取變數的值

# 在Session中讀取變數的值
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("a:", a.eval())
    print("b:", b.eval())
    print
("c:", c.eval()) print("d:", d.eval())

a: 1.0
b: 2.0
c: 3.0
d: 4.0

使用tf.get_variable()

通過一系列引數,獲取一個已經存在的變數,或者建立一個新的變數。

下面即建立一個名為ok,shape為[1]的變數,初始化為6.6。

ok = tf.get_variable("ok", [1], initializer=tf.constant_initializer(6.6))
print("ok:", ok.name)

ok: ok:0

變數的scope

要使用相同name的變數,一般要指定在不同的scope裡。當使用tf.get_variable()

建立變數時,會去檢查計算任務中是否已經建立過這個變數,如果建立過了,而且本次沒有使用共享方式,就會出錯。

with tf.variable_scope("v1"):
    ok = tf.get_variable("ok", [1], initializer=tf.constant_initializer(6.6))
    print("ok:", ok.name)

with tf.variable_scope("v2"):
    ok = tf.get_variable("ok", [1], initializer=tf.constant_initializer(6.6))
    print("ok:", ok.name)

ok: v1/ok:0
ok: v2/ok:0

巢狀scope

變數作用域可以巢狀。

with tf.variable_scope("v4"):
    ok4 = tf.get_variable("ok", [1], initializer=tf.constant_initializer(6.6))
    with tf.variable_scope("v5"):
        ok45 = tf.get_variable("ok", [1], initializer=tf.constant_initializer(6.6))

print("ok4:", ok4.name)
print("ok45:", ok45.name)

ok4: v4/ok:0
ok45: v4/v5/ok:0

Variable的重用

指向同一個Variable的程式變數即重用(共享)了。這在有些需要協作的模型(如GAN)裡是比較關鍵的。

reuse_variables()

用上面的方式指定的兩個變數是不同的,如果要在一個作用域裡指定兩個變數是相同的,可以在該作用域下開啟變數重用。

with tf.variable_scope("v3") as v3:
    p1 = tf.get_variable("p", [1], initializer=tf.constant_initializer(2.2))
    print("p1:", p1.name)
    v3.reuse_variables()  # 開啟變數重用
    p2 = tf.get_variable("p")
    print("p2:", p2.name)

p1: v3/p:0
p2: v3/p:0

# 檢視這兩個變數的值
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(p1), sess.run(p2))

[2.2] [2.2]

共享變數的一般用法

一般是將共享的一個作用域裡的變數放到兩個網路中去,也就是在某個網路中:

with tf.variable_scope("v6"):
    g1 = tf.get_variable("g1", [1], initializer=tf.constant_initializer(1.1))
    with tf.variable_scope("v7"):
        g2 = tf.get_variable("g2", [1], initializer=tf.constant_initializer(2.2))

在另一個網路中,相應的scope開啟reuse(這裡的reuse可以級聯傳遞):

with tf.variable_scope("v6", reuse=True):
    g3 = tf.get_variable("g1", [1], initializer=tf.constant_initializer(1.1))
    with tf.variable_scope("v7"):
        g4 = tf.get_variable("g2", [1], initializer=tf.constant_initializer(2.2))

最終g1和g3共享了,g2和g4共享了:

print(g1.name, g3.name)
print(g2.name, g4.name)

v6/g1:0 v6/g1:0
v6/v7/g2:0 v6/v7/g2:0

關於自動重用

如果在第一個v6上開啟reuse=True,那麼會報錯:

ValueError: Variable v6/g1 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?

也就是說,開啟重用的scope會去找裡面先前已經定義好的變數,並重用之,如果找不到就會出問題。

在某些情況下,開啟自動重用(reuse=tf.AUTO_REUSE)是比較合適的,它可以兼顧“重用”和“建立新變數”,即能重用就重用,不能就直接建立新變數。

with tf.variable_scope("v3", reuse=tf.AUTO_REUSE) as v3:
    p1 = tf.get_variable("p", [1], initializer=tf.constant_initializer(2.2))
    print("p1:", p1.name)
    p2 = tf.get_variable("p")
    print("p2:", p2.name)

p1: v3/p:0
p2: v3/p:0

Variable在scope上的初始化

在scope上可以指定initializer,對於未指定initializer的Variable和子scope,它會級聯傳遞;對於指定了initializer的子scope,它會被覆蓋,並以覆蓋後的值向下級聯傳遞。

with tf.variable_scope("s1", initializer=tf.constant_initializer(1.1)):
    s1a = tf.get_variable("s1a", [1])
    s1b = tf.get_variable("s1b", [1], initializer=tf.constant_initializer(1.2))
    with tf.variable_scope("s2"):
        s2a = tf.get_variable("s1a", [1])
    with tf.variable_scope("s3", initializer=tf.constant_initializer(2.1)):
        s3a = tf.get_variable("s1a", [1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(s1a), sess.run(s1b), sess.run(s2a), sess.run(s3a))

[1.1] [1.2] [1.1] [2.1]