1. 程式人生 > >tensorflow模型的儲存與載入

tensorflow模型的儲存與載入

1.儲存:(儲存的變數都是停放,tf.Variable()中的變數,變數一定要有名字)

saver = tf.train.Saver()

saver.run(sess,"./model4/line_model.ckpt")

 

2.檢視儲存的變數資訊:(將儲存的資訊打印出來)

from tensorflow.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("./model4/liner_model.ckpt",None,Ture)

 

3.將檔案中的引數載入到模型中:

saver.restore(sess,"./model4/line_model.ckpt")
w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
print("w1:",w1)
print("b1:",b1)

 

完整程式碼奉上:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#匯入資料集
x = tf.placeholder(shape=[None,784],dtype=tf.float32)
y = tf.placeholder(shape=[None,10],dtype=tf.float32)
#為輸入輸出定義placehloder

w = tf.Variable(tf.truncated_normal(shape=[784,10],mean=0,stddev=0.5),name="W")
b = tf.Variable(tf.zeros([10]),name="b")
#定義權重
y_pred = tf.nn.softmax(tf.matmul(x,w)+b)
#定義模型結構
loss =tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred),reduction_indices=[1]))
#定義損失函式
opt = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
#定義優化演算法
saver = tf.train.Saver()#儲存模型
sess =tf.Session()
sess.run(tf.global_variables_initializer())
for each in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    loss1 = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})
    opt1 = sess.run(opt,feed_dict={x:batch_xs,y:batch_ys})
    # print(loss1)
    # saver.save(sess,"./model4/line_model.ckpt")#將儲存的模型放在指定的檔案中
# w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
# print(w1)
# b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
# print("b:",b)
# y1 = sess.run(y_pred,feed_dict={x:batch_xs,y:batch_ys})
# print("y:",y1)
# print(len(y1))
#列印儲存的模型引數
# from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# print_tensors_in_checkpoint_file("./model4/line_model.ckpt",None,True)

#將指定檔案中的變數載入到模型中
saver.restore(sess,"./model4/line_model.ckpt")
w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
print("w1:",w1)
print("b1:",b1)