1. 程式人生 > >Tensorflow將模型匯出為一個檔案及介面設定

Tensorflow將模型匯出為一個檔案及介面設定

在上一篇文章中《Tensorflow載入預訓練模型和儲存模型》,我們學習到如何使用預訓練的模型。但注意到,在上一篇文章中使用預訓練模型,必須至少的要4個檔案:

checkpoint
MyModel.meta
MyModel.data-00000-of-00001
MyModel.index

這很不便於我們的使用。有沒有辦法匯出為一個pb檔案,然後直接使用呢?答案是肯定的。在文章《Tensorflow載入預訓練模型和儲存模型》中提到,meta檔案儲存圖結構,weights等引數儲存在data檔案中。也就是說,圖和引數資料時分開儲存的。說的更直白一點,就是meta檔案中沒有weights等資料。但是,值得注意的是,meta檔案會儲存常量。

我們只需將data檔案中的引數轉為meta檔案中的常量即可!

1 模型匯出為一個檔案

1.1 有程式碼並且從頭開始訓練

Tensorflow提供了工具函式tf.graph_util.convert_variables_to_constants()用於將變數轉為常量。看看官網的描述:

if you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.

我們繼續通過一個簡單例子開始:

import tensorflow as tf

w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)

#記住要定義name,後面需要用到
out = tf.multiply(w3,b1,name="out")

# 轉換Variable為constant,並將網路寫入到檔案
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 這裡需要填入輸出tensor的名字
graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

執行可以看到如下日誌:

Converted 3 variables to const ops.

可以看到通過tf.graph_util.convert_variables_to_constants()函式將變數轉為了常量,並存儲在graph.pb檔案中,接下來看看如何使用這個模型。

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as graph:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

執行結果如下:

[100.0]

回到tf.graph_util.convert_variables_to_constants()函式,可以看到,需要傳入Session物件和圖,這都可以理解。看看第三個引數["out"],它是指定這個模型的輸出Tensor

1.2 有程式碼和模型,但是不想重新訓練模型

有模型原始碼時,在匯出模型時就可以通過tf.graph_util.convert_variables_to_constants()函式來將變數轉為常量儲存到圖檔案中。但是很多時候,我們拿到的是別人的checkpoint檔案,即meta、index、data等檔案。這種情況下,需要將data檔案裡面變數轉為常量儲存到meta檔案中。思路也很簡單,先將checkpoint檔案載入,再重新儲存一次即可。

假設訓練和儲存模型程式碼如下:

import tensorflow as tf

w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
b1= tf.Variable(2.0,name="bias")
w3 = tf.add(w1,w2)

#記住要定義name,後面需要用到
out = tf.multiply(w3,b1,name="out")

# 轉換Variable為constant,並將網路寫入到檔案
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # 這裡需要填入輸出tensor的名字
    saver.save(sess, './checkpoint_dir/MyModel', global_step=1000)

此時,模型檔案如下:

checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-1000.meta

如果我們只有以上4個模型檔案,但是可以看到訓練原始碼。那麼,將這4個檔案匯出為一個pb檔案方法如下:

import tensorflow as tf
with tf.Session() as sess:

    #初始化變數
    sess.run(tf.global_variables_initializer())

    #獲取最新的checkpoint,其實就是解析了checkpoint檔案
    latest_ckpt = tf.train.latest_checkpoint("./checkpoint_dir")

    #載入圖
    restore_saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')

    #恢復圖,即將weights等引數加入圖對應位置中
    restore_saver.restore(sess, latest_ckpt)

    #將圖中的變數轉為常量
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def , ["out"])
    #將新的圖儲存到"/pretrained/graph.pb"檔案中
    tf.train.write_graph(output_graph_def, 'pretrained', "graph.pb", as_text=False)

執行後,會有如下日誌:

Converted 3 variables to const ops.

接下來就是使用,使用方法跟前面一致:

import tensorflow as tf
with tf.Session() as sess:
    with open('./pretrained/graph.pb', 'rb') as graph:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

列印資訊如下:

[100.0]

2 模型介面設定

我們注意到,前面只是簡單的獲取一個輸出介面,但是很明顯,我們使用的時候,不可能只有一個輸出,還需要有輸入,接下來我們看看,如何設定輸入和輸出。同樣我們分為有程式碼並且從頭開始訓練,和有程式碼和模型,但是不想重新訓練模型兩種情況。

2.1 有程式碼並且從頭開始訓練

相比1.1中的程式碼略作修改即可,第6行程式碼處做了修改:

import tensorflow as tf

w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")

#這裡將b1改為placeholder,讓使用者輸入,而不是寫死
#b1= tf.Variable(2.0,name="bias")
b1= tf.placeholder(tf.float32, name='bias')

w3 = tf.add(w1,w2)

#記住要定義name,後面需要用到
out = tf.multiply(w3,b1,name="out")

# 轉換Variable為constant,並將網路寫入到檔案
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 這裡需要填入輸出tensor的名字
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)

日誌如下:

Converted 2 variables to const ops.

接下來看看如何使用:

import tensorflow as tf
with tf.Session() as sess:
    with open('./checkpoint_dir/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'bias:0':4.}, return_elements=['out:0'])
        print(sess.run(output))

列印資訊如下:

[200.0]

也就是說,在設定輸入時,首先將需要輸入的資料作為placeholdler,然後在匯入圖tf.import_graph_def()時,通過引數input_map={}來指定輸入。輸出通過return_elements=[]直接引用tensor的name即可。

2.2 有程式碼和模型,但是不想重新訓練模型

在有程式碼和模型,但是不想重新訓練模型情況下,意味著我們不能直接修改匯出模型的程式碼。但是我們可以通過graph.get_tensor_by_name()函式取得圖中的某些中間結果,然後再加入一些邏輯。其實這種情況在上一篇文章已經講了。可以參考上一篇檔案解決,相比“有程式碼並且從頭開始訓練”情況侷限比較大,大部分情況只能是獲取模型的一些中間結果,但是也滿足我們大多數情況使用了。