1. 程式人生 > >Tensorflow學習筆記--模型儲存與調取

Tensorflow學習筆記--模型儲存與調取

注:本文主要通過莫煩的python學習視訊記錄的內容,如果喜歡請支援莫煩python。謝謝

目前tf的模型儲存其實只是引數儲存,所以儲存檔案時你特別要主要以下幾點:

1、一定要設定好引數的資料型別!

2、設定引數的名稱,並且一一對應!

3、讀取引數時,需要設定好模型圖!

下面做一個簡單的demo,供各位參考:

儲存模型:

import tensorflow as tf
import numpy as np

## Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[2,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[2,2,3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net2.ckpt")
    print("Save to path: ", save_path)

提取模型:
import tensorflow as tf
import numpy as np
# 先建立 W, b 的容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
#init = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
    #sess.run(init)
    # 提取變數
    saver.restore(sess, "my_net/save_net2.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

PS:

如果你讀取2次模型變數,你會發現以下錯誤:

NotFoundError (see above for traceback): Key weights_2 not found in checkpoint
	 [[Node: save_3/RestoreV2_11 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_3/Const_0, save_3/RestoreV2_11/tensor_names, save_3/RestoreV2_11/shape_and_slices)]]

原因:你可以看到再次讀取模型時,權重名稱已經變成了weights_2。所以就會報錯誤。