1. 程式人生 > >tensorflow-模型儲存和載入(一)

tensorflow-模型儲存和載入(一)

模型儲存和載入(一)

TensorFlow的模型格式有很多種,針對不同場景可以使用不同的格式。

格式簡介
Checkpoint用於儲存模型的權重,主要用於模型訓練過程中引數的備份和模型訓練熱啟動。
GraphDef用於儲存模型的Graph,不包含模型權重,加上checkpoint後就有模型上線的全部資訊。
SavedModel使用saved_model介面匯出的模型檔案,包含模型Graph和許可權可直接用於上線,TensorFlow和Keras模型推薦使用這種模型格式。
FrozenGraph使用freeze_graph.py對checkpoint和GraphDef進行整合和優化,可以直接部署到Android、iOS等移動裝置上。
TFLite基於flatbuf對模型進行優化,可以直接部署到Android、iOS等移動裝置上,使用介面和FrozenGraph有些差異。

在訓練模型的時候需要儲存模型的中間訓練結果Checkpoint,以便下次迭代訓練或者用作預測。Tensorflow針對這一需求提供了Saver類

1.Saver類提供了向checkpoints檔案儲存和從checkpoints檔案中恢復變數的相關方法。Checkpoints檔案是一個二進位制檔案,它把變數名對映到對應的tensor值。
2.只要提供一個計數器,當計數器觸發時,Saver類可以自動的生成checkpoint檔案。這讓我們可以在訓練過程中儲存多箇中間結果。例如,我們可以儲存每一步訓練的結果。
3.為了避免填滿整個磁碟,Saver可以自動的管理Checkpoints檔案。例如,我們可以指定儲存最近的N個Checkpoints檔案。

方式一:

模型儲存:

通過下面的一段程式碼建立saver物件來管理模型中的變數(預設情況下是所有的變數,也可以自行選擇)。

import tensorflow as tf

# save to file
W = tf.Variable([[1, 2, 8], [4, 5, 8]], dtype=tf.float32, name='weight')
b = tf.Variable([[1, 2, 8]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
        sess.run(init)
        save_path = saver.save(sess, "./my_net/save_net.ckpt")
        print ("save to path:", save_path)

模型載入:

用同一個Saver物件來恢復變數(需要把模型的結構重新定義一遍)。注意,當你從檔案恢復變數時,不需要對它進行初始化,否則會報錯。

import tensorflow as tf
import numpy as np

W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name='weight')
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name='biases')

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "./my_net/save_net.ckpt")
    print ("weights:", sess.run(W))
    print ("biases:", sess.run(b))

恢復訓練:

import tensorflow as tf
import numpy as np
import os

#輸入資料
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0,0.05, x_data.shape)
y_data = np.square(x_data)-0.5+noise

#輸入層
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

#隱層
W1 = tf.Variable(tf.random_normal([1,10]))
b1 = tf.Variable(tf.zeros([1,10])+0.1)
Wx_plus_b1 = tf.matmul(xs,W1) + b1
output1 = tf.nn.relu(Wx_plus_b1)

#輸出層
W2 = tf.Variable(tf.random_normal([10,1]))
b2 = tf.Variable(tf.zeros([1,1])+0.1)
Wx_plus_b2 = tf.matmul(output1,W2) + b2
output2 = Wx_plus_b2

#損失
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-output2),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

#模型儲存載入工具
saver = tf.train.Saver()

#判斷模型儲存路徑是否存在,不存在就建立
if not os.path.exists('tmp/'):
    os.mkdir('tmp/')

#初始化
sess = tf.Session()
if os.path.exists('tmp/checkpoint'): #判斷模型是否存在
    saver.restore(sess, 'tmp/model.ckpt') #存在就從模型中恢復變數
else:
    init = tf.global_variables_initializer() #不存在就初始化變數
    sess.run(init)

#訓練
for i in range(1000):
    _,loss_value = sess.run([train_step,loss], feed_dict={xs:x_data,ys:y_data})
    if(i%50==0): #每50次儲存一次模型
        save_path = saver.save(sess, 'tmp/model.ckpt') #儲存模型到tmp/model.ckpt,注意一定要有一層資料夾,否則儲存不成功!!!
        print("模型儲存:%s 當前訓練損失:%s"%(save_path, loss_value))

這種方法不方便的在於,在載入模型的時候,必須把模型的結構重新定義一遍,然後載入對應名字的變數的值。但是很多時候我們都更希望能夠讀取一個檔案然後就直接使用模型,而不是還要把模型重新定義一遍。所以就需要使用另一種方法。

方式二:

不需重新定義網路結構的方法: tf.train.import_meta_graph

import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)

這個方法可以從檔案中將儲存的graph的所有節點載入到當前的default graph中,並返回一個saver。也就是說,我們在儲存的時候,除了將變數的值儲存下來,其實還有將對應graph中的各種節點儲存下來,所以模型的結構也同樣被儲存下來了。

比如我們想要儲存計算最後預測結果的 y,則應該在訓練階段將它新增到collection中。(這個是不是不需要手動新增)

模型儲存:

### 定義模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定義預測目標
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 建立saver
saver = tf.train.Saver(...variables...)
# 假如需要儲存y,以便在預測時使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # 儲存checkpoint, 同時也預設匯出一個meta_graph
        # graph名為'my-model-{global_step}.meta'.
        saver.save(sess, 'my-model', global_step=step)

模型載入:

checkpoint_file=tf.train.latest_checkpoint(checkpoint_directory)
graph=tf.Graph()

  with graph.as_default():
    session_conf = tf.ConfigProto(allow_safe_placement=True, log_device_placement =False)
    sess = tf.Session(config = session_conf)
    with sess.as_default():
      saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
      saver.restore(sess,checkpoint_file)

      # tf.get_collection() 返回一個list. 但是這裡只要第一個引數即可
      y = tf.get_collection('pred_network')[0]

      # 因為y中有placeholder,所以sess.run(y)的時候還需要用實際待預測的樣本以及相應的引數來填充這些placeholder,而這些需要通過graph的get_operation_by_name方法來獲取。
      input_x = graph.get_operation_by_name('input_x').outputs[0]
      keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

      # 使用y進行預測  
      sess.run(y, feed_dict={input_x:....,  keep_prob:1.0})

這裡有兩點需要注意的: 
一、 saver.restore()時填的檔名,因為在saver.save的時候,每個checkpoint會儲存三個檔案,如

my-model-10000.meta     元模型檔案,儲存圖的結構

my-model-10000.index

my-model-10000.data-00000-of-00001    權重檔案

import_meta_graph時填的是meta檔名。權值都儲存在my-model-10000.data-00000-of-00001這個檔案中,但是如果在restore方法中填這個檔名,就會報錯,應該填的是字首,這個字首可以使用這個方法tf.train.latest_checkpoint(checkpoint_dir)獲取。

二、模型的y中有用到placeholder,在sess.run()的時候肯定要feed對應的資料,因此還要根據具體placeholder的名字,從graph中使用get_operation_by_name方法獲取。