Tensorflow學習筆記--模型儲存與調取
阿新 • • 發佈:2019-02-16
注:本文主要通過莫煩的python學習視訊記錄的內容,如果喜歡請支援莫煩python。謝謝
目前tf的模型儲存其實只是引數儲存,所以儲存檔案時你特別要主要以下幾點:
1、一定要設定好引數的資料型別!
2、設定引數的名稱,並且一一對應!
3、讀取引數時,需要設定好模型圖!
下面做一個簡單的demo,供各位參考:
儲存模型:
import tensorflow as tf import numpy as np ## Save to file # remember to define the same dtype and shape when restore W = tf.Variable([[2,2,3],[3,4,5]], dtype=tf.float32, name='weights') b = tf.Variable([[2,2,3]], dtype=tf.float32, name='biases') init = tf.global_variables_initializer() saver = tf.train.Saver() with tf.Session() as sess: sess.run(init) save_path = saver.save(sess, "my_net/save_net2.ckpt") print("Save to path: ", save_path)
提取模型:
import tensorflow as tf import numpy as np # 先建立 W, b 的容器 W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights") b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases") #init = tf.global_variables_initializer() saver = tf.train.Saver() with tf.Session() as sess: #sess.run(init) # 提取變數 saver.restore(sess, "my_net/save_net2.ckpt") print("weights:", sess.run(W)) print("biases:", sess.run(b))
PS:
如果你讀取2次模型變數,你會發現以下錯誤:
NotFoundError (see above for traceback): Key weights_2 not found in checkpoint
[[Node: save_3/RestoreV2_11 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_3/Const_0, save_3/RestoreV2_11/tensor_names, save_3/RestoreV2_11/shape_and_slices)]]
原因:你可以看到再次讀取模型時,權重名稱已經變成了weights_2。所以就會報錯誤。