1. 程式人生 > >TensorFlow模型檔案儲存和讀取

TensorFlow模型檔案儲存和讀取

一、模型檔案的儲存

在訓練一個TensorFlow模型之後,我們可以將訓練好的模型儲存成檔案,這樣可以方便下一次對新的資料進行預測的時候直接載入訓練好的模型即可獲得結果,下面通過TensorFlow提供的tf.train.Saver函式,將一個模型儲存成檔案,一般習慣性的將TensorFlow的模型檔案命名為*.ckpt檔案。

[python]  view plain  copy
  1. <span style="font-size:14px;">import tensorflow as tf  
  2.   
  3. if __name__ == "__main__":  
  4.     #定義兩個變數  
  5.     a = tf.Variable(tf.constant(1.0,shape=[1],name="a"))  
  6.     b = tf.Variable(tf.constant(2.0,shape=[1],name=
    "b"))  
  7.     c = a + b  
  8.   
  9.     init = tf.initialize_all_variables()  
  10.     sess = tf.Session()  
  11.     sess.run(init)  
  12.   
  13.     #宣告一個儲存  
  14.     saver = tf.train.Saver()  
  15.     saver.save(sess,"./model.ckpt")</span>  
如果,在執行程式的時候報ValueError: Parent directory of model.ckpt doesn't exist, can't save.,只需要將儲存檔案的路徑由model.ckpt改成./model.ckpt即可。執行完上面的程式碼之後,我們會發現在當前的程式目錄下產生四個檔案checkpointmodel.ckpt.data-00000-of-00001model.ckpt.indexmodel.ckpt.meta。會產生四個檔案的原因,之前有介紹過TensorFlow的程式是由計算圖所組成的,所以在持久化的時候TensorFlow會將計算圖的結果和圖上的引數值分成不同的檔案進行儲存。

二、模型檔案的讀取

TensorFlow對於模型檔案的讀取方式也提供了幾種方法,根據讀取不同的檔案來獲取不同的資訊。

1、載入model.ckpt檔案來初始化變數

[python]  view plain  copy
  1. <span style="font-size:14px;">    a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))  
  2.     b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))  
  3.     c = a + b  
  4.     saver = tf.train.Saver()  
  5.     sess = tf.Session()  
  6.     saver.restore(sess,"model.ckpt")  
  7.     print(sess.run(c))  
  8.     #[ 3.]</span>  
在宣告變數的時候,變數的名字,shape要與儲存的模型檔案一致,無論給變數a和b的初始值設定為什麼,最後輸出的結果總是3,因為在儲存模型檔案的時候,已經記錄了變數的初始值。而且,在載入model.ckpt檔案的時候也不需要對變數進行初始化操作。

2、載入持久化圖獲取全部變數

[python]  view plain  copy
  1. <span style="font-size:14px;">    saver = tf.train.import_meta_graph("model.ckpt.meta")  
  2.     sess = tf.Session()  
  3.     saver.restore(sess,"model.ckpt")  
  4.     print(sess.run(tf.get_default_graph().get_tensor_by_name("a:0")))  
  5.     #[ 1.]  
  6.     print(sess.run(tf.get_default_graph().get_tensor_by_name("b:0")))  
  7.     #[ 2.]  
  8.     print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))  
  9.     #[ 3.]</span>  
通過載入model.cpt.meta檔案和model.ckpt檔案來獲取全部的變數,然後通過變數的名稱來獲取變數的值,在通過變數的名字來獲取變數的時候需要注意的是,比如說變數a的name為"a",但是在使用名字的時候不能直接使用“a”來獲取變數的值,如果直接使用“a”的話,會報ValueError: The name 'a' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".錯誤的原因就是需要以<op_name>:<output_index>來獲取變數的值,意思就是名字和下標結合,如"a:0"。還需要注意的就是在獲取變數c的值的時候,不是通過c,而是通過"add:0",因為變數a和b直接求和,生成的變數c的名字TensorFlow預設為"add"。

3、載入指定列表變數

[python]  view plain  copy
  1. <span style="font-size:14px;">    a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))  
  2.     b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))  
  3.     c = a + b  
  4.     saver = tf.train.Saver([a,b])  
  5.     sess = tf.Session()  
  6.     saver.restore(sess,"model.ckpt")  
  7.     print(sess.run(a))  
  8.     #[ 1.]  
  9.     print(sess.run(b))  
  10.     #[ 2.]</span>  
通過在第一種方式的基礎上,初始化Saver的時候指定一個列表,在初始化模型檔案中的變數時,只會載入指定列表的變數。如果在上面程式碼的基礎上,在最後在加一句print(sess.run(c)),輸出的結果為[3.],明明沒有指定載入c,為什麼還能輸出3呢?其實,原因也很簡單,因為我們已經初始化了變數a和b,所以通過計算a+b自然就可以計算出c了。如果,指定載入列表為[a]而輸出b的話,就會報tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value Variable_1
[[Node: _retval_Variable_1_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_1)]]
,使用一個沒有初始化的變數。

4、載入變數名的重新命名

tensorfow提供了一種方法可以修改載入模型中的變數名,通過tf.train.Saver(),帶參的形式來修改變數名稱。

[python]  view plain  copy
  1. <span style="font-size:14px;">   #重新定義兩個變數v1和v2  
  2.     v1 = tf.Variable(tf.constant(3.,shape=[1]),name="v1")  
  3.     v2 = tf.Variable(tf.constant(4.,shape=[1]),name="v2")  
  4.     #將模型中的變數名a重新命名為v1,將模型中的變數名b重新命名為v2  
  5.     save = tf.train.Saver({"a":v1,"b":v2})  
  6.     sess = tf.Session()  
  7.     save.restore(sess,"model.ckpt")  
  8.     print(sess.run(v1))  
  9.     print(sess.run(v2))</span>  
通過傳入一個字典,來修改TensorFlow的變數名,a和b是模型中的變數名稱,而v1和v2是將變數a和b重新命名之後的名稱。如果,你用我的第一個程式來儲存一個模型檔案,通過上面的方法來修改變數的名稱的時候,你會得到一個錯誤NotFoundError (see above for traceback): Key b not found in checkpoint難道上面的程式碼有問題?其實,這個坑是在儲存模型檔案留下來的。要想解決這個問題,首先還是看重新命名變數名稱的這個程式,TensorFlow提供的tf.train.Saver({"a":v1,"b":v2})方法,它會去checkpoint這個檔案中找變數名為a和b的變數,然後再修改變數名,找不到這兩個變數自然就報錯了。第一個程式的坑,在這句程式碼中tf.Variable(tf.constant(1.0,shape=[1],name="a")),也許細心的朋友已經發現問題了,其實我們只是將常量命名為了"a",並沒有將變數命名,這樣就導致了問題的發現。所以,要想解決這個問題,我們只需要將tf.tran.Saver中的{“a”:v1,"b":v2}修改為{"Variable":v1,"Variable_1":v2}或者將tf.Variable(tf.constant(1.0,shape=[1],name="a"))修改為tf.Variable(tf.constant(1.0,shape=[1]),name="a"),其中的Variable和Variable_1是TensorFlow預設的變數名稱,我們可以通過a.name的方式檢視TensorFlow中的變數名稱。