tensorflow 訓練模型的儲存 與 讀取已儲存的模型進行測試
阿新 • • 發佈:2018-11-10
在實際中,通常需要將經過大量訓練的較好模型引數儲存起來,在實際應用以訓練好的模型進行預測。
TensorFlow中提供了模型儲存的模組 tensorflow.train.Saver()
1. 匯入tensorflow模組 import tensorflow as tf
2. 建立模型儲存的Saver物件 saver = tf.train.Saver
3. 儲存訓練好的模型,設定模型儲存的路徑 checkpoint_dir = './model/' , 其中model是當前路徑下儲存模型的資料夾名稱
saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step) ,model.ckpt-step 是模型的檔名,step是迭代次數。
需注意的是,最後一次迭代的訓練模型有可能不是準確度最高的一次,如果想儲存迭代中準確度最高的一次,需要新增判斷。
在迭代訓練前設定初始最大準確度 max_acc = 0
在每次迭代中進行判斷
checkpoint_dir = './model/' if val_acc > max_acc: max_acc = val_acc saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step)
4. 用已儲存的模型進行測試
model_file = tf.train.latest_checkpoint(checkpoint_dir)
saver.restore(sess, model_file)
output = sess.run(pre_result, feed_dict={x: test_x})