1. 程式人生 > >Tensorflow學習筆記(一)--變數作用域與模型載入

Tensorflow學習筆記(一)--變數作用域與模型載入

1、變數作用域機制主要由兩個函式實現:

tf.get_variable(<name>, <shape>, <initializer>)
tf.variable_scope(<scope_name>)

2、常用的initializer有

tf.constant_initializer(value) # 初始化一個常量值,
tf.random_uniform_initializer(a, b) # 從a到b均勻分佈的初始化,
tf.random_normal_initializer(mean, stddev) # 用所給平均值和標準差初始化正態分佈.

3、變數作用域的tf.variable_scope()帶有一個名稱,它將會作為字首用於變數名,並且帶有一個重用標籤(後面會說到)來區分以上的兩種情況。巢狀的作用域附加名字所用的規則和檔案目錄的規則很類似。

對於採用了變數作用域的網路結構,結構虛擬碼如下:

import tensorflow as tf 

def my_image_filter():
    with tf.variable_scope("conv1"):
        weights = tf.get_variable("weights", [1], initializer=tf.random_normal_initializer())
    print("weights:%s" % weights.name)
    with tf.variable_scope("conv2"):
        biases = tf.get_variable("biases", [1], initializer=tf.constant_initializer(0.3))
    print("biases:%s" % biases.name)
	
result1 = my_image_filter()


輸出:

weights:conv1/weights:0
biases:conv2/biases:0

4、如果連續呼叫兩次my_image_filter()將會報出ValueError:

result1 = my_image_filter()
result2 = my_image_filter()

ValueError: Variable conv1/weights already exists, disallowed. Did you mean to s
et reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

解決方案: 若不在網路架構中採用變數作用域則不會報錯,但是會產生兩組變數,而不是共享變數。

a、 當tf.get_variable_scope().reuse == True時;該情況下會搜尋一個已存在的“foo/v”並將該變數的值賦給v1,若找不到“foo/v”變數則會丟擲ValueError。

b、當tf.get_variable_scope().reuse == tf.AUTO_REUSE時,該方法是為重用變數所設定;該情況不會丟擲ValueError

import tensorflow as tf 

def my_image_filter():
    with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE):
    # Variables created here will be named "conv1/weights", "conv1/biases".
        weights = tf.get_variable("weights", [1], initializer=tf.random_normal_initializer())
    print("weights:%s" % weights.name)
    with tf.variable_scope("conv2", reuse=tf.AUTO_REUSE):
    # Variables created here will be named "conv2/weights", "conv2/biases".
        biases = tf.get_variable("biases", [1], initializer=tf.constant_initializer(0.3))
    print("biases:%s" % biases.name)
	
result1 = my_image_filter()
result2 = my_image_filter()

5、 在模型載入時,如果網路框架中採用變數作用域,也會出現該問題:Variable conv1/weights already exists disallowed. Did you mean to set reuse=True

解決方案:

如果Restart kernel 之後再次執行就不會有問題了(相當於重啟了spyder,這樣不能從根本解決問題。而且多次重啟,也不太好。)

這個問題主要是由於再次執行的時候,之前的計算圖已經存在了,再次執行時會和之前已經存在的產生衝突。解決方法:
在程式碼前面加一句:tf.reset_default_graph()

tf.reset_default_graph()
ckpt_file = tf.train.latest_checkpoint(model_path)
print(ckpt_file)
paths['model_path'] = ckpt_file
model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
model.build_graph()

參考文獻: