本文是根據 TensorFlow 官方教程翻譯總結的學習筆記,主要介紹了在 TensorFlow 中如何共享引數變數。
教程中首先引入共享變數的應用場景,緊接著用一個例子介紹如何實現共享變數(主要涉及到 tf.variable_scope()
和tf.get_variable()
兩個介面),最後會介紹變數域 (Variable Scope) 的工作方式。
遇到的問題
假設我們建立了一個簡單的 CNN 網路:
def my_image_filter(input_images):
conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv1_weights")
conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")
conv1 = tf.nn.conv2d(input_images, conv1_weights,
strides=[1, 1, 1, 1], padding='SAME')
relu1 = tf.nn.relu(conv1 + conv1_biases)
conv2_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv2_weights")
conv2_biases = tf.Variable(tf.zeros([32]), name="conv2_biases")
conv2 = tf.nn.conv2d(relu1, conv2_weights,
strides=[1, 1, 1, 1], padding='SAME')
return tf.nn.relu(conv2 + conv2_biases)
這個網路中用 tf.Variable()
初始化了四個引數。
不過,別看我們用一個函式封裝好了網路,當我們要呼叫網路進行訓練時,問題就會變得麻煩。比如說,我們有 image1
和 image2
兩張圖片,如果將它們同時丟到網路裡面,由於引數是在函式裡面定義的,這樣一來,每呼叫一次函式,就相當於又初始化一次變數:
# First call creates one set of 4 variables.
result1 = my_image_filter(image1)
# Another set of 4 variables is created in the second call.
result2 = my_image_filter(image2)
當然了,我們很快也能找到解決辦法,那就是把引數的初始化放在函式外面,把它們當作全域性變數,這樣一來,就相當於全域性「共享」了嘛。比如說,我們可以用一個 dict
在函式外定義引數:
variables_dict = {
"conv1_weights": tf.Variable(tf.random_normal([5, 5, 32, 32]),
name="conv1_weights")
"conv1_biases": tf.Variable(tf.zeros([32]), name="conv1_biases")
... etc. ...
}
def my_image_filter(input_images, variables_dict):
conv1 = tf.nn.conv2d(input_images, variables_dict["conv1_weights"],
strides=[1, 1, 1, 1], padding='SAME')
relu1 = tf.nn.relu(conv1 + variables_dict["conv1_biases"])
conv2 = tf.nn.conv2d(relu1, variables_dict["conv2_weights"],
strides=[1, 1, 1, 1], padding='SAME')
return tf.nn.relu(conv2 + variables_dict["conv2_biases"])
# The 2 calls to my_image_filter() now use the same variables
result1 = my_image_filter(image1, variables_dict)
result2 = my_image_filter(image2, variables_dict)
不過,這種方法對於熟悉面向物件的你來說,會不會有點彆扭呢?因為它完全破壞了原有的封裝。也許你會說,不礙事的,只要將引數和filter
函式都放到一個類裡即可。不錯,面向物件的方法保持了原有的封裝,但這裡出現了另一個問題:當網路變得很複雜很龐大時,你的引數列表/字典也會變得很冗長,而且如果你將網路分割成幾個不同的函式來實現,那麼,在傳參時將變得很麻煩,而且一旦出現一點點錯誤,就可能導致巨大的 bug。
為此,TensorFlow 內建了變數域這個功能,讓我們可以通過域名來區分或共享變數。通過它,我們完全可以將引數放在函式內部例項化,再也不用手動儲存一份很長的引數列表了。
用變數域實現共享引數
這裡主要包括兩個函式介面:
tf.get_variable(<name>, <shape>, <initializer>)
:根據指定的變數名例項化或返回一個tensor
物件;tf.variable_scope(<scope_name>)
:管理tf.get_variable()
變數的域名。
tf.get_variable()
的機制跟 tf.Variable()
有很大不同,如果指定的變數名已經存在(即先前已經用同一個變數名通過 get_variable()
函式例項化了變數),那麼 get_variable()
只會返回之前的變數,否則才創造新的變數。
現在,我們用 tf.get_variable()
來解決上面提到的問題。我們將卷積網路的兩個引數變數分別命名為 weights
和 biases
。不過,由於總共有 4 個引數,如果還要再手動加個 weights1
、weights2
,那程式碼又要開始噁心了。於是,TensorFlow 加入變數域的機制來幫助我們區分變數,比如:
def conv_relu(input, kernel_shape, bias_shape):
# Create variable named "weights".
weights = tf.get_variable("weights", kernel_shape,
initializer=tf.random_normal_initializer())
# Create variable named "biases".
biases = tf.get_variable("biases", bias_shape,
initializer=tf.constant_initializer(0.0))
conv = tf.nn.conv2d(input, weights,
strides=[1, 1, 1, 1], padding='SAME')
return tf.nn.relu(conv + biases)
def my_image_filter(input_images):
with tf.variable_scope("conv1"):
# Variables created here will be named "conv1/weights", "conv1/biases".
relu1 = conv_relu(input_images, [5, 5, 32, 32], [32])
with tf.variable_scope("conv2"):
# Variables created here will be named "conv2/weights", "conv2/biases".
return conv_relu(relu1, [5, 5, 32, 32], [32])
我們先定義一個 conv_relu()
函式,因為 conv 和 relu 都是很常用的操作,也許很多層都會用到,因此單獨將這兩個操作提取出來。然後在 my_image_filter()
函式中真正定義我們的網路模型。注意到,我們用 tf.variable_scope()
來分別處理兩個卷積層的引數。正如註釋中提到的那樣,這個函式會在內部的變數名前面再加上一個「scope」字首,比如:conv1/weights
表示第一個卷積層的權值引數。這樣一來,我們就可以通過域名來區分各個層之間的引數了。
不過,如果直接這樣呼叫 my_image_filter
,是會拋異常的:
result1 = my_image_filter(image1)
result2 = my_image_filter(image2)
# Raises ValueError(... conv1/weights already exists ...)
因為 tf.get_variable()
雖然可以共享變數,但預設上它只是檢查變數名,防止重複。要開啟變數共享,你還必須指定在哪個域名內可以共用變數:
with tf.variable_scope("image_filters") as scope:
result1 = my_image_filter(image1)
scope.reuse_variables()
result2 = my_image_filter(image2)
到這一步,共享變數的工作就完成了。你甚至都不用在函式外定義變數,直接呼叫同一個函式並傳入不同的域名,就可以讓 TensorFlow 來幫你管理變量了。
背後的工作方式
變數域的工作機理
接下來我們再仔細梳理一下這背後發生的事情。
我們要先搞清楚,當我們呼叫 tf.get_variable(name, shape, dtype, initializer)
時,這背後到底做了什麼。
首先,TensorFlow 會判斷是否要共享變數,也就是判斷 tf.get_variable_scope().reuse
的值,如果結果為 False
(即你沒有在變數域內呼叫scope.reuse_variables()
),那麼 TensorFlow 認為你是要初始化一個新的變數,緊接著它會判斷這個命名的變數是否存在。如果存在,會丟擲 ValueError
異常,否則,就根據 initializer
初始化變數:
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1])
assert v.name == "foo/v:0"
而如果 tf.get_variable_scope().reuse == True
,那麼 TensorFlow 會執行相反的動作,就是到程式裡面尋找變數名為 scope name + name
的變數,如果變數不存在,會丟擲 ValueError
異常,否則,就返回找到的變數:
with tf.variable_scope("foo"):
v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True):
v1 = tf.get_variable("v", [1])
assert v1 is v
瞭解變數域背後的工作方式後,我們就可以進一步熟悉其他一些技巧了。
變數域的基本使用
變數域可以巢狀使用:
with tf.variable_scope("foo"):
with tf.variable_scope("bar"):
v = tf.get_variable("v", [1])
assert v.name == "foo/bar/v:0"
我們也可以通過 tf.get_variable_scope()
來獲得當前的變數域物件,並通過 reuse_variables()
方法來設定是否共享變數。不過,TensorFlow 並不支援將 reuse
值設為 False
,如果你要停止共享變數,可以選擇離開當前所在的變數域,或者再進入一個新的變數域(比如,再進入一個 with
語句,然後指定新的域名)。
還需注意的一點是,一旦在一個變數域內將 reuse
設為 True
,那麼這個變數域的子變數域也會繼承這個 reuse
值,自動開啟共享變數:
with tf.variable_scope("root"):
# At start, the scope is not reusing.
assert tf.get_variable_scope().reuse == False
with tf.variable_scope("foo"):
# Opened a sub-scope, still not reusing.
assert tf.get_variable_scope().reuse == False
with tf.variable_scope("foo", reuse=True):
# Explicitly opened a reusing scope.
assert tf.get_variable_scope().reuse == True
with tf.variable_scope("bar"):
# Now sub-scope inherits the reuse flag.
assert tf.get_variable_scope().reuse == True
# Exited the reusing scope, back to a non-reusing one.
assert tf.get_variable_scope().reuse == False
捕獲變數域物件
如果一直用字串來區分變數域,寫起來容易出錯。為此,TensorFlow 提供了一個變數域物件來幫助我們管理程式碼:
with tf.variable_scope("foo") as foo_scope:
v = tf.get_variable("v", [1])
with tf.variable_scope(foo_scope)
w = tf.get_variable("w", [1])
with tf.variable_scope(foo_scope, reuse=True)
v1 = tf.get_variable("v", [1])
w1 = tf.get_variable("w", [1])
assert v1 is v
assert w1 is w
記住,用這個變數域物件還可以讓我們跳出當前所在的變數域區域:
with tf.variable_scope("foo") as foo_scope:
assert foo_scope.name == "foo"
with tf.variable_scope("bar")
with tf.variable_scope("baz") as other_scope:
assert other_scope.name == "bar/baz"
with tf.variable_scope(foo_scope) as foo_scope2:
assert foo_scope2.name == "foo" # Not changed.
在變數域內初始化變數
每次初始化變數時都要傳入一個 initializer
,這實在是麻煩,而如果使用變數域的話,就可以批量初始化引數了:
with tf.variable_scope("foo", initializer=tf.constant_initializer(0.4)):
v = tf.get_variable("v", [1])
assert v.eval() == 0.4 # Default initializer as set above.
w = tf.get_variable("w", [1], initializer=tf.constant_initializer(0.3)):
assert w.eval() == 0.3 # Specific initializer overrides the default.
with tf.variable_scope("bar"):
v = tf.get_variable("v", [1])
assert v.eval() == 0.4 # Inherited default initializer.
with tf.variable_scope("baz", initializer=tf.constant_initializer(0.2)):
v = tf.get_variable("v", [1])
assert v.eval() == 0.2 # Changed default initializer.