tensorflow儲存模型和恢復模型
阿新 • • 發佈:2018-12-28
儲存模型
w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") b1= tf.Variable(2.0,name="bias") feed_dict ={w1:4,w2:8} w3 = tf.add(w1,w2) w4 = tf.multiply(w3,b1,name="op_to_restore") sess = tf.Session() sess.run(tf.global_variables_initializer()) #建立saver的例項 saver = tf.train.Saver() #列印w4 print(sess.run(w4,feed_dict)) #w4=(w1+w2)*b1,值為24 #儲存權重 saver.save(sess, 'my_test_model',global_step=1000)
恢復模型
import tensorflow as tf sess=tf.Session() #載入graph saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess,tf.train.latest_checkpoint('./')) #直接訪問已儲存的變數 print(sess.run('bias:0')) # This will print 2, which is the value of bias that we saved #準備網路的輸入 graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0} #訪問想要執行的操作 op_to_restore = graph.get_tensor_by_name("op_to_restore:0") #打印出60 print(sess.run(op_to_restore,feed_dict))