TensorFlow: saving and restoring a model

Saving the model
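The example below builds a small graph that computes w4 = (w1 + w2) * b1, gives each tensor an explicit name so it can be looked up again later, and then writes a checkpoint with tf.train.Saver.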

import tensorflow as tf

w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a Saver instance
saver = tf.train.Saver()

# Print w4: w4 = (w1 + w2) * b1 = (4 + 8) * 2 = 24
print(sess.run(w4, feed_dict))

# Save the weights (and the graph definition) to a checkpoint
saver.save(sess, 'my_test_model', global_step=1000)
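saver.save writes several files: my_test_model-1000.meta for the graph definition, plus .index/.data files and a checkpoint file for the variable values. If you only want to keep the most recent checkpoints, or save a subset of variables, tf.train.Saver takes a few constructor options; a minimal sketch (the two-hour interval is just an illustrative value):

# Keep only the 4 most recent checkpoints, plus one every 2 hours
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

# Save only selected variables (here just the bias) instead of the whole model
saver = tf.train.Saver({'bias': b1})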

Restoring the model
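Restoring happens in two steps: tf.train.import_meta_graph recreates the saved graph from the .meta file, and saver.restore then loads the saved variable values from the latest checkpoint.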

import tensorflow as tf

sess = tf.Session()
# Load the graph structure from the .meta file
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
# Restore the variable values from the latest checkpoint in the current directory
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access a saved variable directly by its tensor name
print(sess.run('bias:0'))
# This prints 2.0, the value of the bias we saved

# Prepare the network's inputs
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Access the operation we want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
# Prints 60.0: (13 + 17) * 2
print(sess.run(op_to_restore, feed_dict))
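A restored graph can also be extended with new operations before running it. A minimal sketch, assuming the session and tensors from the restore code above are still in scope (add_on_op is just an illustrative name):

# Build a new operation on top of the restored output
add_on_op = tf.multiply(op_to_restore, 2)

# (13 + 17) * 2 * 2 = 120
print(sess.run(add_on_op, feed_dict))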