1. 程式人生 > >Tensorflow模型引數的Saver儲存讀取

Tensorflow模型引數的Saver儲存讀取

一、Saver儲存

import tensorflow as tf
import numpy as np

#定義W和b
W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight')
b = tf.Variable([1,2,3],dtype = tf.float32,name = 'biases')
#注:初始化變數Variable
init = tf.global_variables_initializer()


#建立tf.train.Saver() 來儲存, 提取變數。
#建立my_net資料夾,儲存變數
saver =  tf.train.Saver()

sess = tf.Session()
sess.run(init)
#儲存變數到路徑my_net
save_path = saver.save(sess,"my_net/save_net.ckpt")#儲存格式為ckpt

#輸出儲存的變數
print("save path:",save_path)

結果: 

二、Saver讀取

import tensorflow as tf
import numpy as np


#建立W,b的空容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

#不需要初始化變數

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))