1. 程式人生 > >Tensorflow模型儲存和過載

Tensorflow模型儲存和過載

最近因為專案要求,需要把模型的訓練和測試過程分開,這裡主要涉及兩個過程:訓練圖的存取和引數的存取。
以下所有/home/yy/xiajbxie/model是我的模型的儲存路徑,將其換成你自己的即可。

tf.train.Saver()

Saver的作用中文社群已經講得相當清楚。tf.train.Saver()類的基本操作時save()和restore()函式,分別負責模型引數的儲存和恢復。引數儲存示例如下:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1"
) v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") # Add an op to initialize the variables. init_op = tf.initialize_all_variables() # Add ops to save and restore all the variables. saver = tf.train.Saver() # initialize the variables, save the variables to disk. with tf.Session() as sess: sess.run(init_op) v1, v2 = sess.run([v1, v2]) print(v1) print(v2) # Do some work with the model.
# Save the variables to disk. save_path = saver.save(sess, "/home/yy/xiajbxie/model") print "Model saved in file: ", save_path

執行結果:

[[-0.0493206   0.12752049]]
[[ 1.9456626   0.6319563  -0.1296857 ]
 [-0.7834143   0.33656874 -0.96077037]]
Model saved in file:  /home/yy/xiajbxie/model

引數恢復示例如下:

import tensorflow as
tf # Create some variables. v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, use the saver to restore variables from disk, and # do some work with the model. with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "/home/yy/xiajbxie/model") print "Model restored." print(sess.run([v1, v2]))

執行結果:

Model restored.
[array([[-0.0493206 ,  0.12752049]], dtype=float32), array([[ 1.9456626 ,  0.6319563 , -0.1296857 ],
       [-0.7834143 ,  0.33656874, -0.96077037]], dtype=float32)]

saver.save()函式的引數為需儲存的會話,以及模型的儲存路徑。儲存後我們進入模型的儲存路徑會看到4個新增檔案,4個檔案根據tensorflow版本不同名字不同,以上例為例,1.2版本4個檔案如下:
1. checkpoint:其中儲存模型所在的路徑
2. model.meta:包含計算圖的完整資訊
3. model.index:與下面的檔案一起儲存所有的變數值
4. model.data-00000-of-00001

可以看到,在模型引數恢復前需事先定義要恢復的變數,並且變數名需要與模型中儲存的變數名保持一致。
官方文件的說法是無需在引數恢復前對其進行初始化,但實際操作的時候有出現過報錯“FailedPreconditionError (see above for traceback): Attempting to use uninitialized value”的情況,此時利用tf.global_variables_initializer()初始化變數可解決問題。

tf.train.import_meta_graph()

模型引數恢復之前需要先定義模型中儲存的變數,如果不想這樣做可以把模型的計算圖也恢復出來。tf.train.import_meta_graph()函式就用於恢復模型,它的輸入引數為模型路徑,返回一個Saver類例項,再呼叫這個例項的restore()函式就可以恢復其引數了。示例如下:

import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('/home/yy/xiajbxie/model.meta')

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('/home/yy/xiajbxie')
    if ckpt and ckpt.model_checkpoint_path:
        print ckpt.model_checkpoint_path
        new_saver.restore(sess, ckpt.model_checkpoint_path)

    v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
    v2 = tf.get_default_graph().get_tensor_by_name('v2:0')
    print(sess.run([v1, v2]))

執行結果:

/home/yy/xiajbxie/model
[array([[-0.0493206 ,  0.12752049]], dtype=float32), array([[ 1.9456626 ,  0.6319563 , -0.1296857 ],
       [-0.7834143 ,  0.33656874, -0.96077037]], dtype=float32)]

其中get_checkpoint_state()用於在傳入的路徑中尋找tensorflow檢查點。

tips

  • 在不知道要過載的tensor叫什麼名字時可以在訓練階段列印變數名來觀察。
  • 不能在與訓練資料相同的計算圖下載入以前儲存的計算圖,如果實在要這樣做也要保證兩個計算圖中不包含名字相同的變數。
  • 利用tf.Graph()來生成新的計算圖,利用tf.Graph().as_default()來將新生成的計算圖設定為預設。