1. 程式人生 > >tensorflow 儲存和載入模型 -2

tensorflow 儲存和載入模型 -2

1、

我們經常在訓練完一個模型之後希望儲存訓練的結果,這些結果指的是模型的引數,以便下次迭代的訓練或者用作測試。Tensorflow針對這一需求提供了Saver類。
  1. Saver類提供了向checkpoints檔案儲存和從checkpoints檔案中恢復變數的相關方法Checkpoints檔案是一個二進位制檔案,它把變數名對映到對應的tensor值
  2. 只要提供一個計數器,當計數器觸發時,Saver類可以自動的生成checkpoint檔案。這讓我們可以在訓練過程中儲存多箇中間結果。例如,我們可以儲存每一步訓練的結果。
  3. 為了避免填滿整個磁碟,Saver可以自動的管理Checkpoints檔案。例如,我們可以指定儲存最近的N個Checkpoints檔案
2、code
import tensorflow as tf
import numpy as np

isTrain = True
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = '/home/jdlu/jdluTensor/test/tmp/'

x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if isTrain:
        for i in xrange(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print(sess.run(w))
        print(sess.run(b))

說明:

訓練的過程:

1、先設定isTrain=True,然後會儲存模型,設定isTrain=False會將訓練好的模型載入進來進行測試

2、train_steps:表示訓練的次數,例子中使用100
3、checkpoint_steps:表示訓練多少次儲存一下checkpoints,例子中使用50
4、checkpoint_dir:表示checkpoints檔案的儲存路徑,例子中使用當前路徑

    if isTrain:
        for i in xrange(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                saver.save(sess, checkpoint_dir + 'model.ckpt',global_step = i+1)
說明:每訓練checkpoint_steps就儲存一次模型,在訓練的過程中,就可以多次儲存模型。

測試的過程:

1、測試的過程就是載入訓練模型好的模型

 ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print(sess.run(w))
        print(sess.run(b))

說明:

checkpoint的檔案內容:


儲存model的路徑下的檔案內容:


saver.save(sess, checkpoint_dir + 'model.ckpt',global_step = i+1)

每次儲存一次都會相應生成三個檔案,分別是.data-00000-of-00001,.index,.meta

==================================================================================================================