1. 程式人生 > >tensorflow 訓練模型的儲存 與 讀取已儲存的模型進行測試

tensorflow 訓練模型的儲存 與 讀取已儲存的模型進行測試

在實際中,通常需要將經過大量訓練的較好模型引數儲存起來,在實際應用以訓練好的模型進行預測。

TensorFlow中提供了模型儲存的模組 tensorflow.train.Saver()

1. 匯入tensorflow模組                   import tensorflow as tf

2. 建立模型儲存的Saver物件      saver = tf.train.Saver

3. 儲存訓練好的模型,設定模型儲存的路徑 checkpoint_dir =  './model/' , 其中model是當前路徑下儲存模型的資料夾名稱

     saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step) ,model.ckpt-step 是模型的檔名,step是迭代次數。

    需注意的是,最後一次迭代的訓練模型有可能不是準確度最高的一次,如果想儲存迭代中準確度最高的一次,需要新增判斷。

    在迭代訓練前設定初始最大準確度 max_acc = 0

    在每次迭代中進行判斷  

checkpoint_dir = './model/'
if val_acc > max_acc:
    max_acc = val_acc
    saver.save(sess, checkpoint_dir+'model.ckpt', global_step = step)

 4. 用已儲存的模型進行測試 

     model_file = tf.train.latest_checkpoint(checkpoint_dir)

     saver.restore(sess, model_file)

     output = sess.run(pre_result, feed_dict={x: test_x})