tensorflow模型的儲存與載入
阿新 • • 發佈:2018-11-02
1.儲存:(儲存的變數都是停放,tf.Variable()中的變數,變數一定要有名字)
saver = tf.train.Saver()
saver.run(sess,"./model4/line_model.ckpt")
2.檢視儲存的變數資訊:(將儲存的資訊打印出來)
from tensorflow.tools.inspect_checkpoint import print_tensors_in_checkpoint_file print_tensors_in_checkpoint_file("./model4/liner_model.ckpt",None,Ture)
3.將檔案中的引數載入到模型中:
saver.restore(sess,"./model4/line_model.ckpt")
w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
print("w1:",w1)
print("b1:",b1)
完整程式碼奉上:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) #匯入資料集 x = tf.placeholder(shape=[None,784],dtype=tf.float32) y = tf.placeholder(shape=[None,10],dtype=tf.float32) #為輸入輸出定義placehloder w = tf.Variable(tf.truncated_normal(shape=[784,10],mean=0,stddev=0.5),name="W") b = tf.Variable(tf.zeros([10]),name="b") #定義權重 y_pred = tf.nn.softmax(tf.matmul(x,w)+b) #定義模型結構 loss =tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred),reduction_indices=[1])) #定義損失函式 opt = tf.train.GradientDescentOptimizer(0.05).minimize(loss) #定義優化演算法 saver = tf.train.Saver()#儲存模型 sess =tf.Session() sess.run(tf.global_variables_initializer()) for each in range(1000): batch_xs,batch_ys = mnist.train.next_batch(100) loss1 = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys}) opt1 = sess.run(opt,feed_dict={x:batch_xs,y:batch_ys}) # print(loss1) # saver.save(sess,"./model4/line_model.ckpt")#將儲存的模型放在指定的檔案中 # w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys}) # print(w1) # b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys}) # print("b:",b) # y1 = sess.run(y_pred,feed_dict={x:batch_xs,y:batch_ys}) # print("y:",y1) # print(len(y1)) #列印儲存的模型引數 # from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file # print_tensors_in_checkpoint_file("./model4/line_model.ckpt",None,True) #將指定檔案中的變數載入到模型中 saver.restore(sess,"./model4/line_model.ckpt") w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys}) b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys}) print("w1:",w1) print("b1:",b1)