1. 程式人生 > >tensorflow學習筆記——模型持久化的原理,將CKPT轉為pb檔案,使用pb模型預測

tensorflow學習筆記——模型持久化的原理,將CKPT轉為pb檔案,使用pb模型預測

  由題目就可以看出,本節內容分為三部分,第一部分就是如何將訓練好的模型持久化,並學習模型持久化的原理,第二部分就是如何將CKPT轉化為pb檔案,第三部分就是如何使用pb模型進行預測。

一,模型持久化

  為了讓訓練得到的模型儲存下來方便下次直接呼叫,我們需要將訓練得到的神經網路模型持久化。下面學習通過TensorFlow程式來持久化一個訓練好的模型,並從持久化之後的模型檔案中還原被儲存的模型,然後學習TensorFlow持久化的工作原理和持久化之後檔案中的資料格式。

1,持久化程式碼實現

  TensorFlow提供了一個非常簡單的API來儲存和還原一個神經網路模型。這個API就是 tf.train.Saver 類。使用 tf.train.saver()  儲存模型時會產生多個檔案,會把計算圖的結構和圖上引數取值分成了不同的檔案儲存。這種方式是在TensorFlow中是最常用的儲存方式。

  下面程式碼給出了儲存TensorFlow計算圖的方法:

#_*_coding:utf-8_*_
import tensorflow as tf
import os

# 宣告兩個變數並計算他們的和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
# 宣告 tf.train.Saver類用於儲存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # 將模型儲存到model.ckpt檔案中
    model_path = 'model/model.ckpt'
    saver.save(sess, model_path)

  上面的程式碼實現了持久化一個簡單的TensorFlow模型的功能。在這段程式碼中,通過saver.save 函式將TensorFlow模型儲存到了 model/model.path 檔案中。TensorFlow模型一般會儲存在後綴為 .ckpt 的檔案中,雖然上面的程式只指定了一個檔案路徑,但是這個檔案目錄下面會出現三個檔案。這是因為TensorFlow會將計算圖的結構和圖上引數取值分開儲存。

  執行上面程式碼,我們檢視model檔案裡面的檔案如下:

   下面解釋一下檔案分別是幹什麼的:

  • checkpoint檔案是檢查點檔案,檔案儲存了一個目錄下所有模型檔案列表。
  • model.ckpt.data檔案儲存了TensorFlow程式中每一個變數的取值
  • model.ckpt.index檔案則儲存了TensorFlow程式中變數的索引
  • model.ckpt.meta檔案則儲存了TensorFlow計算圖的結構(可以簡單理解為神經網路的網路結構),該檔案可以被 tf.train.import_meta_graph 載入到當前預設的圖來使用。

      下面程式碼給出載入這個模型的方法:

#_*_coding:utf-8_*_
import tensorflow as tf

#使用和儲存模型程式碼中一樣的方式來宣告變數
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # 載入已經儲存的模型,並通過已經儲存的模型中的變數的值來計算加法
    model_path = 'model/model.ckpt'
    saver.restore(sess, model_path)
    print(sess.run(result))

# 結果如下:[3.]

  這段載入模型的程式碼基本上和儲存模型的程式碼是一樣的。在載入模型的程式中也是先定義了TensorFlow計算圖上所有運算,並聲明瞭一個 tf.train.Saver類。兩段程式碼唯一不同的是,在載入模型的程式碼中沒有執行變數的初始化過程,而是將變數的值通過已經儲存的模型加載出來。如果不希望重複定義圖上的運算,也可以直接載入已經持久化的圖,以下程式碼給出一個樣例:

import tensorflow as tf

# 直接載入持久化的圖
model_path = 'model/model.ckpt'
model_path1 = 'model/model.ckpt.meta'
saver = tf.train.import_meta_graph(model_path1)

with tf.Session() as sess:
    saver.restore(sess, model_path)
    # 通過張量的的名稱來獲取張量
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

# 結果如下:[3.]

  其上面給出的程式中,預設儲存和載入了TensorFlow計算圖上定義的所有變數。但是有時可能只需要儲存或者載入部分變數。比如,可能有一個之前訓練好的五層神經網路模型,現在想嘗試一個六層神經網路,那麼可以將前面五層神經網路中的引數直接載入到新的模型,而僅僅將最後一層神經網路重新訓練。

  為了儲存或者載入部分變數,在宣告 tf.train.Saver 類時可以提供一個列表來指定需要儲存或者載入的變數。比如在載入模型的程式碼中使用 saver = tf.train.Saver([v1]) 命令來構建 tf.train.Saver 類,那麼只有變數 v1 會被載入進來。如果執行修改後只加載了 v1 的程式碼會得到變數未初始化的錯誤:

tensorflow.python.framework.errors.FailedPreconditionError:Attempting to 
use uninitialized value v2

  因為 v2 沒有被載入,所以v2在執行初始化之前是沒有值的。除了可以選取需要被載入的變數,tf.train.Saver 類也支援在儲存或者載入時給變數重新命名。

  下面給出一個簡單的樣例程式說明變數重新命名是如何被使用的。

import tensorflow as tf

# 這裡宣告的變數名稱和已經儲存的模型中變數的的名稱不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')

# 如果直接使用 tf.train.Saver() 來載入模型會報變數找不到的錯誤,下面顯示了報錯資訊
# tensorflow.python.framework.errors.FailedPreconditionError:Tensor name 'other-v2'
# not found in checkpoint file model/model.ckpt

# 使用一個字典來重新命名變數就可以載入原來的模型了
# 這個字典指定了原來名稱為 v1 的變數現在載入到變數 v1中(名稱為 other-v1)
# 名稱為v2 的變數載入到變數 v2中(名稱為 other-v2)
saver = tf.train.Saver({'v1': v1, 'v2': v2})

  在這個程式中,對變數 v1 和 v2 的名稱進行了修改。如果直接通過 tf.train.Saver 預設的建構函式來載入儲存的模型,那麼程式會報變數找不到的錯誤,因為儲存時候的變數名稱和載入時變數的名稱不一致。為了解決這個問題,Tensorflow 可以通過字典(dictionary)將模型儲存時的變數名和需要載入的變數聯絡起來。這樣做的主要目的之一就是方便使用變數的滑動平均值。在之前介紹了使用變數的滑動平均值可以讓神經網路模型更加健壯(robust)。在TensorFlow中,每一個變數的滑動平均值是通過影子變數維護的,所以要獲取變數的滑動平均值實際上就是獲取這個影子變數的取值。如果在載入模型時將影子變數對映到變數本身,那麼在使用訓練好的模型時就不需要再呼叫函式來獲取變數的滑動平均值了。這樣就大大方便了滑動平均模型的時域。下面程式碼給出了一個儲存滑動平均模型的樣例:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
# 在沒有申明滑動平均模型時只有一個變數 v,所以下面語句只會輸出 v:0
for variables in tf.global_variables():
    print(variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在申明滑動平均模型之後,TensorFlow會自動生成一個影子變數 v/ExponentialMovingAverage
# 於是下面的語句會輸出 v:0 和 v/ExponentialMovingAverage:0
for variables in tf.global_variables():
    print(variables.name)

saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 儲存時,TensorFlow會將v:0 和 v/ExponentialMovingAverage:0 兩個變數都儲存下來
    saver.save(sess, 'model/modeltest.ckpt')
    print(sess.run([v, ema.average(v)]))
    # 輸出結果 [10.0, 0.099999905]

  下面程式碼給出瞭如何通過變數重新命名直接讀取變數的滑動平均值。從下面程式的輸出可以看出,讀取的變數 v 的值實際上是上面程式碼中變數 v 的滑動平均值。通過這個方法,就可以使用完全一樣的程式碼來計算滑動平均模型前向傳播的結果:

v = tf.Variable(0, dtype=tf.float32, name='v')
# 通過變數重新命名將原來變數v的滑動平均值直接賦值給 V
saver = tf.train.Saver({'v/ExponentialMovingAverage': v})
with tf.Session() as sess:
    saver.restore(sess, 'model/modeltest.ckpt')
    print(sess.run(v))
    # 輸出 0.099999905  這個值就是原來模型中變數 v 的滑動平均值

  為了方便載入時重新命名滑動平均變數,tf.train.ExponentialMovingAverage 類提供了 variables_tp_restore 函式來生成 tf.train.Saver類所需要的變數重新命名字典,一下程式碼給出了 variables_to_restore 函式的使用樣例:

v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)

# 通過使用 variables_to_restore 函式可以直接生成上面程式碼中提供的字典
# {'v/ExponentialMovingAverage': v}
# 下面程式碼會輸出 {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
print(ema.variables_to_restore())

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, 'model/modeltest.ckpt')
    print(sess.run(v))
    # 輸出 0.099999905  即原來模型中變數 v 的滑動平均值

  使用 tf.train.Saver 會儲存進行TensorFlow程式所需要的全部資訊,然後有時並不需要某些資訊。比如在測試或者離線預測時,只需要知道如何從神經網路的輸出層經過前向傳播計算得到輸出層即可,而不需要類似於變數初始化,模型儲存等輔助接點的資訊。而且,將變數取值和計算圖結構分成不同的檔案儲存有時候也不方便,於是TensorFlow提供了 convert_variables_to_constants 函式,通過這個函式可以將計算圖中的變數及其取值通過常量的方式儲存,這樣整個TensorFlow計算圖可以統一存放在一個檔案中,該方法可以固化模型結構,而且儲存的模型可以移植到Android平臺。

convert_variables_to_constants固化模型結構

  下面給出一個樣例:

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # 匯出當前計算圖的GraphDef部分,只需要這一步就可以完成從輸入層到輸出層的過程
    graph_def = tf.get_default_graph().as_graph_def()

    # 將圖中的變數及其取值轉化為常量,同時將圖中不必要的節點去掉
    # 在下面,最後一個引數['add']給出了需要儲存的節點名稱
    # add節點是上面定義的兩個變數相加的操作
    # 注意這裡給出的是計算節點的的名稱,所以沒有後面的 :0
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, (['add']))
    # 將匯出的模型存入檔案
    with tf.gfile.GFile('model/combined_model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())

  通過下面的程式可以直接計算定義加法運算的結果,當只需要得到計算圖中某個節點的取值時,這提供了一個更加方便的方法,以後將使用這種方法來使用訓練好的模型完成遷移學習。

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'model/combined_model.pb'
    # 讀取儲存的模型檔案,並將檔案解析成對應的GraphDef Protocol Buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # 將graph_def 中儲存的圖載入到當前的圖中,
    # return_elements = ['add: 0'] 給出了返回的張量的名稱
    # 在儲存的時候給出的是計算節點的名稱,所以為add
    # 在載入的時候給出的張量的名稱,所以是 add:0
    result = tf.import_graph_def(graph_def, return_elements=['add: 0'])
    print(sess.run(result))
    # 輸出 [array([3.], dtype=float32)]

 

 2,持久化原理及資料格式

  上面學習了當呼叫 saver.save 函式時,TensorFlow程式會自動生成四個檔案。TensorFlow模型的持久化就是通過這個四個檔案完成的。這裡我們詳細學習一下這個三個檔案中儲存的內容以及資料格式。

  TensorFlow是一個通過圖的形式來表述計算的程式設計系統,TensorFlow程式中所有計算都會被表達為計算圖上的節點。TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的資訊以及執行計算圖中節點所需要的元資料。TensorFlow中元圖是由 MetaGraphDef Protocol Buffer 定義的。MetaGraphDef 中的內容就構成了TensorFlow 持久化的第一個檔案,以下程式碼給出了MetaGraphDef型別的定義:

message MetaGraphDef{
    MeatInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string,CollectionDef> collection_def = 4;
    map<string,SignatureDef> signature_def = 5;
}

  從上面程式碼中可以看到,元圖中主要記錄了五類資訊,下面結合變數相加樣例的持久化結果,逐一介紹MetaGraphDef型別的每一個屬性中儲存的資訊。儲存 MetaGraphDef 資訊的檔案預設為以 .meta 為字尾名,在上面,檔案 model.ckpt.meta 中儲存的就是元圖的資料。直接執行其樣例得到的是一個二進位制檔案,無法直接檢視。為了方便除錯,TensorFlow提供了 export_meta_graph 函式,這函式支援以json格式匯出 MetaGraphDef Protocol Buffer。下面程式碼展示瞭如何使用這個函式:

import tensorflow as tf

# 定義變數相加的計算
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()
# 通過  export_meta_graph() 函式匯出TensorFlow計算圖的元圖,並儲存為json格式
saver.export_meta_graph('model/model.ckpt.meda.json', as_text=True)

  通過上面給出的程式碼,我們可以將計算圖元圖以json的格式匯出並存儲在 model.ckpt.meda.json 檔案中。下面給出這個檔案的大概內容:

  我們從JSON檔案中可以看到確實是五類資訊。下面結合這JSON檔案的具體內容來學習一下TensorFlow中元圖儲存的資訊。

1,meta_info_def屬性

  meta_info_def 屬性是通過MetaInfoDef定義的。它記錄了TensorFlow計算圖中的元資料以及TensorFlow程式中所有使用到的運算方法的資訊,下面是 MetaInfoDef Protocol Buffer 的定義:

message MetaInfoDef{
    #saver沒有特殊指定,預設屬性都為空。meta_info_def屬性裡只有stripped_op_list屬性不能為空。
    #該屬性不能為空
    string meta_graph_version = 1;
    #該屬性記錄了計算圖中使用到的所有運算方法的資訊,該函式只記錄運算資訊,不記錄計算的次數
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
}

  TensorFlow計算圖的元資料包括了計算圖的版本號(meta_graph_version屬性)以及使用者指定的一些標籤(tags屬性)。如果沒有在 saver中特殊指定,那麼這些屬性都預設為空。

  在model.ckpt.meta.json檔案中,meta_info_def 屬性裡只有 stripped_op_list屬性是不為空的。stripped_op_list 屬性記錄了TensorFlow計算圖上使用到的所有運算方法的資訊。注意stripped_op_list 屬性儲存的是 TensorFlow 運算方法的資訊,所以如果某一個運算在TensorFlow計算圖中出現了多次,那麼在 stripped_op_list  也只會出現一次。比如在 model.ckpt.meta.jspm 檔案的 stripped_op_list  屬性只有一個 Variable運算,但是這個運算在程式中被使用了兩次。 

  stripped_op_list 屬性的型別是  OpList。OpList 型別是一個 OpDef型別的列表,以下程式碼給出了 OpDef 型別的定義:

message opDef{
    string name = 1;#定義了運算的名稱
    repeated ArgDef input_arg = 2; #定義了輸入,屬性是列表
    repeated ArgDef output_arg =3; #定義了輸出,屬性是列表
    repeated AttrDef attr = 4;#給出了其他運算的引數資訊
    string summary = 5;
    string description = 6;
    OpDeprecation deprecation = 8;
    bool is_commutative = 18;
    bool is_aggregate = 16
    bool is_stateful = 17;
    bool allows_uninitialized_input = 19;
};

  OpDef 型別中前四個屬性定義了一個運算最核心的資訊。OpDef 中的第一個屬性 name 定義了運算的名稱,這也是一個運算唯一的識別符號。在TensorFlow計算圖元圖的其他屬性中,比如下面要學習的GraphDef屬性,將通過運算名稱來引用不同的運算。OpDef 的第二個和第三個屬性為 input_arg 和 output_arg,他們定義了運算的輸出和輸入。因為輸入輸出都可以有多個,所以這兩個屬性都是列表。第四個屬性Attr給出了其他的運算引數資訊。在JSON檔案中共定義了七個運算,下面將給出比較有代表性的一個運算來輔助說明OpDef 的資料結構。

op {
    name: "Add"
    input_arg{
        name: "x"
        type_attr:"T"
    }
    input_arg{
        name: "y"
        type_attr:"T"
    }
    output_arg{
        name: "z"
        type_attr:"T"
    }
    attr{
        name:"T"
        type:"type"
        allow_values{
            list{
                type:DT_HALF
                type:DT_FLOAT
                ...
            }
        }
    }
}

  上面給出了名稱為Add的運算。這個運算有兩個輸入和一個輸出,輸入輸出屬性都指定了屬性 type_attr,並且這個屬性的值為 T。在OpDef的Attr屬性中,必須要出現名稱(name)為 T的屬性。以上樣例中,這個屬性指定了運算輸入輸出允許的引數型別(allowed_values)。

2,graph_def 屬性

  graph_def 屬性主要記錄了TensorFlow 計算圖上的節點資訊。TensorFlow計算圖的每一個節點對應了TensorFlow程式中一個運算,因為在 meta_info_def 屬性中已經包含了所有運算的具體資訊,所以 graph_def 屬性只關注運算的連線結構。graph_def屬性是通過 GraphDef Protocol Buffer 定義的,graph_def主要包含了一個 NodeDef型別的列表。一下程式碼給出了 graph_def 和NodeDef型別中包含的資訊:

message GraphDef{
    #GraphDef的主要資訊儲存在node屬性中,他記錄了Tensorflow計算圖上所有的節點資訊。
    repeated NodeDef node = 1;
    VersionDef versions = 4; #主要儲存了Tensorflow的版本號
};

message NodeDef{
    #NodeDef型別中有一個名稱屬性name,他是一個節點的唯一識別符號,在程式中,通過節點的名稱來獲得相應的節點。
    string name = 1;

    '''
    op屬性給出了該節點使用的Tensorflow運算方法的名稱。
    通過這個名稱可以在TensorFlow計算圖元圖的meta_info_def屬性中找到該運算的具體資訊。
    '''
    string op = 2;

    '''
    input屬性是一個字串列表,他定義了運算的輸入。每個字串的取值格式為弄的:src_output
    node部分給出節點名稱,src_output表明了這個輸入是指定節點的第幾個輸出。
    src_output=0時可以省略src_output部分
    '''
    repeated string input = 3;

    #制定了處理這個運算的裝置,可以是本地或者遠端的CPU or GPU。屬性為空時自動選擇
    string device = 4;

    #制定了和當前運算有關的配置資訊
    map<string, AttrValue> attr = 5;
};

  GraphDef中的versions屬性比較簡單,它主要儲存了TensorFlow的版本號。和其他屬性類似,NodeDef 型別中有一個名稱屬性 name,它是一個節點的唯一識別符號,在TensorFlow程式中可以通過節點的名稱來獲取響應節點。 NodeDef 型別中 的 device屬性指定了處理這個運算的裝置。執行TensorFlow運算的裝置可以是本地機器的CPU或者GPU,當device屬性為空時,TensorFlow在執行時會自動選取一個最適合的裝置來執行這個運算,最後NodeDef型別中的Attr屬性指定了和當前運算相關的配置資訊。

  下面列舉了 model.ckpt.meta.json 檔案中的一個計算節點來更加具體的瞭解graph_def屬性:

graph def {
    node {
        name: "v1"
        op: "Variable"
        attr {
            key:"_output_shapes"
            value {
                list{ shape { dim { size: 1 } } }
            }
        }
    }
    attr { 
        key :"dtype"
        value {
            type: DT_FLOAT
            }
        }           
        ...
    }
    node {
        name :"add"
        op :"Add"
        input :"v1/read" #read指讀取變數v1的值
        input: "v2/read"
        ...
    }
    node {
        name: "save/control_dependency" #指系統在完成tensorflow模型持久化過程中自動生成一個運算。
        op:"Identity"
        ...
    }
    versions {
        producer :24 #給出了檔案使用時的Tensorflow版本號。
    }
}

  上面給出了 model.ckpt.meta.json檔案中 graph_def 屬性裡面比較有代表性的幾個節點。第一個節點給出的是變數定義的運算。在TensorFlow中變數定義也是一個運算,這個運算的名稱為 v1(name:),運算方法的名稱是Variable(op: "Variable")。定義變數的運算可以有很多個,於是在NodeDef型別的node屬性中可以有多個變數定義的節點。但是定義變數的運算方法只用到了一個,於是在MetaInfoDef型別的 stripped_op_list 屬性中只有一個名稱為Variable 的運算方法。除了制定計算圖中的節點的名稱和運算方法。NodeDef型別中還定義了運算相關的屬性。在節點 v1中,Attr屬性指定了這個變數的維度以及型別。

  給出的第二個節點是代表加法運算的節點。它指定了2個輸入,一個為 v1/read,另一個為 v2/read。其中 v1/read 代表的節點可以讀取變數 v1的值,因為 v1的值是節點 v1/read的第一個輸出,所以後面的:0就可以省略了。v2/read也類似的代表了變數v2的取值。以上樣例檔案中給出的最後一個名稱為 save/control_dependency,該節點是系統在完成TensorFlow模型持久化過程中自動生成的一個運算。在樣例檔案的最後,屬性versions給出了生成 model.ckpt.meta.json 檔案時使用的TensorFlow版本號。

3,saver_def 屬性

  saver_def 屬性中記錄了持久化模型時需要用到的一些引數,比如儲存到檔案的檔名,儲存操作和載入操作的名稱以及儲存頻率,清理歷史記錄等。saver_def 屬性的型別為SaverDef,其定義如下:

message SaverDef {
    string filename_tensor_name = 1;
    string save_tensor_name = 2;
    string restore_op_name = 3;
    int32 max_to_keep = 4;
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;
    enum CheckpointFormatVersion {
        LEGACY = 0;
        V1 = 1;
        V2 = 2;
    }
    CheckpointFormatVersion version = 7;
}

  下面給出了JSON檔案中 saver_def 屬性的內容:

saver_def {
  filename_tensor_name: "save/Const:0"
  save_tensor_name: "save/control_dependency:0"
  restore_op_name: "save/restore_all"
  max_to_keep: 5
  keep_checkpoint_every_n_hours: 10000.0
  version: V2
}

  filename_tensor_name 屬性給出了儲存檔名的張量名稱,這個張量就是節點 save/Const的第一個輸出。save_tensor_name屬性給出了持久化TensorFlow模型的運算所對應的節點名稱。從上面的檔案中可以看出,這個節點就是在 graph_def 屬性中給出的 save/control_dependency節點。和持久化TensorFlow模型運算對應的是載入TensorFlow模型的運算,這個運算的名稱是由 restore_op_name 屬性指定。max_to_keep 屬性和 keep_checkpoint_every_n_hours屬性設定了 tf.train.Saver 類清理之前儲存的模型的策略。比如當 max_to_keep 為5的時候,在第六次呼叫 saver.save 時,第一次儲存的模型就會被自動刪除,通過設定 keep_checkpoint_every_n_hours,每n小時可以在 max_to_keep 的基礎上多儲存一個模型。

4,collection def 屬性

  在TensorFlow的計算圖(tf.Graph)中可以維護不同集合,而維護這些集合的底層實現就是通過collection_def 這個屬性。collection_def 屬性是一個從集合名稱到集合內容的對映,其中集合名稱為字串,而集合內容為 CollectionDef Protocol Buffer。以下程式碼給出了 CollectionDef型別的定義:

message CollectionDef {
    message Nodelist {
    #用於維護計算圖上的節點集合
        repeated string value = 1;
    }

    message BytesList {
    #維護字串或者系列化之後的Procotol Buffer的集合。例如張量是通過Protocol Buffer表示的,而張量的集合是通過BytesList維護的。
        repeated bytes value = 1 ;
    }

    message Int64List {
        repeated int64 value = 1[packed = true];
    }
    message FloatList {
        repeated float value = 1[packed = true] ;
    }
    message AnyList {
        repeated google.protobuf.Any value= 1;
    }
    oneof kind {
        NodeList node_list = 1;
        BytesList bytes_lista = 2;
        Int64List int64_list = 3;
        Floatlist float_list = 4;
        AnyList any_list = 5;
    }
}

  通過上面的定義可以看出,TensorFlow計算圖上的集合主要可以維護四類不同的集合。NodeList用於維護計算圖上節點的集合。BytesList 可以維護字串或者系列化之後 Procotol Buffer的集合。比如張量是通過Procotol Buffer表示的,而張量的集合是通過BytesList維護的,我們將在JSON檔案中看到具體樣例。Int64List用於維護整數集合,FloatList用於維護實數集合。下面給出了JSON檔案中collection_def 屬性的內容:

collection_def {
  key: "trainable_variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}

  從上面的檔案可以看到樣例程式中維護了兩個集合。一個是所有變數的集合,這個集合的名稱是Variables。另外一個是可訓練變數的集合。名為 trainable_variables。在樣例程式中,這兩個集合中的元素是一樣的,都是變數 v1和 v2,他們是系統自動維護的。

  model.ckpt 檔案中列表的第一行描述了檔案的元資訊,比如在這個檔案中儲存的變數列表,列表剩下的每一行儲存了一個變數的片段。變數片段的資訊是通過SavedSlice Protocol Buffer 定義的。SavedSlice 型別中儲存了變數的名稱,當前片段的資訊以及變數取值。TensorFlow提供了  tf.train.NewCheckpointReader 類來檢視 model.ckpt檔案中儲存的變數資訊,下面程式碼展示瞭如何使用tf.train.NewCheckpointReader 類:

#_*_coding:utf-8_*_
import tensorflow as tf

# tf.train.NewCheckpointReader()  可以讀取 checkpoint檔案中儲存的所有變數
reader = tf.train.NewCheckpointReader('model/model.ckpt')

# 獲取所有變數列表,這是一個從變數名到變數維度的字典
all_variables = reader.get_variable_to_shape_map()
for variable_name in all_variables:
    # variable_name 為變數名稱, all_variables[variable_name]為變數的維度
    print(variable_name, all_variables[variable_name])

#獲取名稱為v1 的變數的取值
print('Value for variable v1 is ', reader.get_tensor('v1'))
'''
v1 [1]     # 變數v1的維度為[1]
v2 [1]     # 變數v2的維度為[1]
Value for variable v1 is  [1.]   # 變數V1的取值為1
'''

  最後一個檔案的名字是固定的,叫checkpoint。這個檔案是 tf.train.Saver類自動生成且自動維護的。在 checkpoint 檔案中維護了由一個 tf.train.Saver類持久化的所有 TensorFlow模型檔案的檔名。當某個儲存的TensorFlow模型檔案被刪除的,這個模型所對應的檔名也會從checkpoint檔案中刪除。checkpoint中內容格式為 CheckpointState Protocol Buffer,下面給出了 CheckpointState 型別的定義。

message CheckpointState {
    string model_checkpoint_path = 1,
    repeated string all_model_checkpoint_paths = 2;
}

  model_checkpoint_path 屬性儲存了最新的TensorFlow模型檔案的檔名。 all_model_checkpoint_paths 屬性列表了當前還沒有被刪除的所有TensorFlow模型檔案的檔名。下面給出了生成的某個checkpoint檔案:

model_checkpoint_path: "modeltest.ckpt"
all_model_checkpoint_paths: "modeltest.ckpt"

 

二,將CKPT轉化為pb格式

  很多時候,我們需要將TensorFlow的模型匯出為單個檔案(同時包含模型結構的定義與權重),方便在其他地方使用(如在Android中部署網路)。利用 tf.train.write_graph() 預設情況下只能匯出了網路的定義(沒有權重),而利用 tf.train.Saver().save() 匯出的檔案 graph_def 與權重時分離的,因此需要採用別的方法。我們知道,graph_def 檔案中沒有包含網路中的 Variable值(通常情況儲存了權重),但是卻包含了constant 值,所以如果我們能把Variable 轉換為 constant,即可達到使用一個檔案同時儲存網路架構與權重的目標。

  (PS:利用tf.train.write_graph() 儲存模型,該方法只是儲存了模型的結構,並不儲存訓練完畢的引數值。)

  TensorFlow 為我們提供了 convert_variables_to_constants() 方法,該方法可以固化模型結構,將計算圖中的變數取值以常量的形式儲存,而且儲存的模型可以移植到Android平臺。

  將CKPT轉換成 PB格式的檔案的過程如下:

  • 1,通過傳入 CKPT模型的路徑得到模型的圖和變數資料
  • 2,通過 import_meta_graph 匯入模型中的圖
  • 3,通過saver.restore 從模型中恢復圖中各個變數的資料
  • 4,通過 graph_util.convert_variables_to_constants 將模型持久化

  下面的CKPT 轉換成 PB格式例子,是之前訓練的GoogleNet InceptionV3模型儲存的ckpt轉pb檔案的例子:

#_*_coding:utf-8_*_
import tensorflow as tf
from tensorflow.python.framework import graph_util
from create_tf_record import *

resize_height = 224  # 指定圖片高度
resize_width = 224   # 指定圖片寬度

def freeze_graph(input_checkpoint, output_graph):
    '''

    :param input_checkpoint:
    :param output_graph:  PB 模型儲存路徑
    :return:
    '''
    # 檢查目錄下ckpt檔案狀態是否可用
    # checkpoint = tf.train.get_checkpoint_state(model_folder)
    # 得ckpt檔案路徑
    # input_checkpoint = checkpoint.model_checkpoint_path

    # 指定輸出的節點名稱,該節點名稱必須是元模型中存在的節點
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # 獲得預設的圖
    input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢復圖並得到資料
        # 模型持久化,將變數值固定
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            # 等於:sess.graph_def
            input_graph_def=input_graph_def,
            # 如果有多個輸出節點,以逗號隔開
            output_node_names=output_node_names.split(","))

        # 儲存模型
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())  # 序列化輸出
        # 得到當前圖有幾個操作節點
        print("%d ops in the final graph." % len(output_graph_def.node))

        # for op in graph.get_operations():
        #     print(op.name, op.values())

  說明

  • 1,函式 freeze_graph中,最重要的就是要確定“指定輸出的節點名稱”,這個節點名稱必須是原模型中存在的節點,對於 freeze 操作,我們需要定義輸出節點的名字。因為網路其實是比較複雜的,定義了輸出節點的名字,那麼freeze操作的時候就只把輸出該節點所需要的子圖都固化下來,其他無關的就捨棄掉。因為我們 freeze 模型的目的是接下來做預測,所以 output_node_names 一般是網路模型最後一層輸出的節點名稱,或者說我們預測的目標。
  • 2,在儲存的時候,通過 convert_variables_to_constants 函式來指定需要固化的節點名稱,對於下面的程式碼,需要固化的節點只有一個:output_node_names。注意節點名稱與張量名稱的區別。比如:“input:0  是張量的名稱”,而“input” 表示的是節點的名稱。
  • 3,原始碼中通過 graph=tf.get_default_graph() 獲得預設的圖,這個圖就是由 saver=tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 恢復的圖,因此必須先執行 tf.train.import_meta_graph,再執行 tf.get_default_graph()。
  • 4,實質上,我們可以直接在恢復的會話 sess 中,獲得預設的網路圖,更簡單的方法,如下:
def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型儲存路徑
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt檔案狀態是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt檔案路徑

    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        # 恢復圖並得到資料
        saver.restore(sess, input_checkpoint)
        # 模型持久化,將變數值固定
        output_graph_def = graph_util.convert_variables_to_constants(  
            sess=sess,
            input_graph_def=sess.graph_def,  # 等於:sess.graph_def
            # 如果有多個輸出節點,以逗號隔開
            output_node_names=output_node_names.split(","))

        # 儲存模型
        with tf.gfile.GFile(output_graph, "wb") as f:
            # 序列化輸出
            f.write(output_graph_def.SerializeToString())
        # 得到當前圖有幾個操作節點
        print("%d ops in the final graph." % len(output_graph_def.node))  

  呼叫方法很簡單,輸入 ckpt 模型路徑,輸出 Pb模型的路徑即可:

# 輸入ckpt模型路徑
input_checkpoint='model/model.ckpt-10000'

# 輸出pb模型的路徑
out_pb_path="model/frozen_model.pb"

# 呼叫freeze_graph將ckpt轉為pb
freeze_graph(input_checkpoint,out_pb_path)  

  注意:在儲存的時候,通過convert_variables_to_constants 函式來指定需要固化的節點名稱,對於上面的程式碼,需要固化的節點只有一個 : output_nideo_names。因此,其他網路模型,也可以通過簡單的修改輸出的節點名稱output_node_names將ckpt轉為pb檔案。

  PS:注意節點名稱,應包含 name_scope 和 variable_scope名稱空間,並用“/”隔開,如“InceptionV3/Logits/SpatialSqueeze”。

2.1 對指定輸出的節點名稱的理解

  如果說我們使用InceptionV3演算法進行訓練,那麼指定輸出的節點名稱如下:

# 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
output_node_names = "InceptionV3/Logits/SpatialSqueeze"

  那麼為什麼呢?

  我去查看了InceptionV3的原始碼,首先模型的輸入名字叫做 InceptionV3;

  其次它要的是輸出的節點,我們看InceptionV3演算法的輸出,也就是最後一層的原始碼,部分原始碼如下:

# Final pooling and prediction
with tf.variable_scope('Logits'):
  if global_pool:
    # Global average pooling.
    net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='GlobalPool')
    end_points['global_pool'] = net
  else:
    # Pooling with a fixed kernel size.
    kernel_size = _reduced_kernel_size_for_small_input(net, [8, 8])
    net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                          scope='AvgPool_1a_{}x{}'.format(*kernel_size))
    end_points['AvgPool_1a'] = net
  if not num_classes:
    return net, end_points
  # 1 x 1 x 2048
  net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
  end_points['PreLogits'] = net
  # 2048
  logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                       normalizer_fn=None, scope='Conv2d_1c_1x1')
  if spatial_squeeze:
    logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
  # 1000
end_points['Logits'] = logits
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  我們會發現最後一層的名字為  Logits,輸出的的name = 'SpatialSqueeze'。

  所以我的理解是指定輸出的節點名稱是模型在程式碼中的名稱+最後一層的名稱+輸出節點的名稱。當然這裡只有一個輸出。

  如果不知道網路節點名稱,或者說不想去模型中找節點名稱,那麼我們可以在載入完模型的圖資料之後,可以輸出圖中的節點資訊檢視一下模型的輸入輸出節點:

for op in tf.get_default_graph().get_operations():
    print(op.name, op.values())

  這樣就可以找出輸出節點名稱。那我也在考慮如果只輸出最後節點的名稱是否可行呢?

  我測試了名字改為下面幾種:

    # output_node_names = 'SpatialSqueeze'
    # output_node_names = 'MobilenetV1/SpatialSqueeze'
    output_node_names = 'MobilenetV1/Logits/SpatialSqueeze'

  也就是不新增模型名稱和最後一層的名稱,新增模型名稱不新增最後一層的名稱。均報錯:

AssertionError: MobilenetV1/SpatialSqueeze is not in graph

  所以這裡還是乖乖使用全稱。

那最後輸出的節點名稱到底是什麼呢?怎麼樣可以直接高效的找出呢?

  首先呢,我個人認為,最後輸出的那一層,應該必須把節點名稱命名出來,另外怎麼才能確定我們的圖結構裡有這個節點呢?百度了一下,有人說可以在TensorBoard中查詢到,TensorBoard只能在Linux中使用,在Windows中得到的TensorBoard檢視不了,是亂碼檔案,在Linux中就沒有問題。所以如果你的Windows可以檢視,就不需要去Linux中跑了。

  檢視TensorBoard

tensorboard --logdir = “儲存tensorboard的絕對路徑”

  敲上面的命令,然後就可以得到一個網址,把這個網址複製到瀏覽器上開啟,就可以得到圖結構,然後點開看看,有沒有output這個節點,也可以順便檢視一下自己的網路圖。但是這個方法我沒有嘗試。我繼續百度了一下,哈哈哈哈,查到了下面的方法。

  就是如果可以按照下面四步驟走的話基本就不需要上面那麼麻煩了:

  首先在ckpt模型的輸入輸出張量名稱,然後將ckpt檔案生成pb檔案;再檢視生成的pb檔案的輸入輸出節點,執行pb檔案,進行網路預測。所以這裡關注的重點就是如何檢視ckpt網路的輸入輸出張量名稱和如何檢視生成的pb檔案的輸入輸出節點。

2.2  檢視ckpt網路的輸入輸出張量名稱

  首先我們找到網路訓練後生成的ckpt檔案,執行下面程式碼檢視自己模型的輸入輸出張量名稱(用於儲存pb檔案時保留這兩個節點):

def check_out_pb_name(checkpoint_path):
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        res = reader.get_tensor(key)
        print('tensor_name: ', key)
        print('a.shape: %s'%[res.shape])

if __name__ == '__main__':
    # 輸入ckpt模型路徑
    checkpoint_path = 'modelsmobilenet/model.ckpt-100000'
    check_out_pb_name(checkpoint_path)

  這裡我繼續使用自己用的mobilenetV1模型,執行後的程式碼部分結果如下:

tensor_name:  MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_0/weights
a.shape: [(3, 3, 3, 32)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Adadelta_1
a.shape: [(3, 3, 256, 1)]
tensor_name:  MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Adadelta
a.shape: [(3, 3, 256, 1)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/moving_variance
a.shape: [(32,)]
tensor_name:  MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Adadelta_1
a.shape: [(256,)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/beta
a.shape: [(32,)]
tensor_name:  MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/beta/Adadelta
a.shape: [(32,)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/gamma
a.shape: [(32,)]

... ...

tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Adadelta_1
a.shape: [(3, 3, 512, 1)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/weights
a.shape: [(1, 1, 512, 512)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/weights/Adadelta
a.shape: [(1, 1, 512, 512)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/weights/Adadelta_1
a.shape: [(1, 1, 512, 512)]
tensor_name:  MobilenetV1/Logits/Conv2d_1c_1x1/weights
a.shape: [(1, 1, 1024, 51)]
tensor_name:  MobilenetV1/Logits/Conv2d_1c_1x1/weights/Adadelta_1
a.shape: [(1, 1, 1024, 51)]
tensor_name:  MobilenetV1/Logits/Conv2d_1c_1x1/weights/Adadelta
a.shape: [(1, 1, 1024, 51)]

  我的模型是使用TensorFlow官網中標準的MoiblenetV1模型,所以輸入輸出張量比較容易找到,那如果自己的模型比較複雜(或者說是別人重構的模型),那如何找呢?

  那找到模型的定義,然後在模型的最前端打印出輸入張量,在最後打印出輸出張量。

  注意上面雖然最後輸出的張量名稱為:MobilenetV1/Logits/Conv2d_1c_1x1,但是如果我們直接用這個,還是會報錯的,這是為什麼呢?這就得去看模型檔案,上面也有,這裡再貼上一下(還是利用MobilenetV1模型):

with tf.variable_scope(scope, 'MobilenetV1', [inputs], reuse=reuse) as scope:
  with slim.arg_scope([slim.batch_norm, slim.dropout],
                      is_training=is_training):
    net, end_points = mobilenet_v1_base(inputs, scope=scope,
                                        min_depth=min_depth,
                                        depth_multiplier=depth_multiplier,
                                        conv_defs=conv_defs)
    with tf.variable_scope('Logits'):
      if global_pool:
        # Global average pooling.
        net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
        end_points['global_pool'] = net
      else:
        # Pooling with a fixed kernel size.
        kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
        net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                              scope='AvgPool_1a')
        end_points['AvgPool_1a'] = net
      if not num_classes:
        return net, end_points
      # 1 x 1 x 1024
      net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
      logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                           normalizer_fn=None, scope='Conv2d_1c_1x1')
      if spatial_squeeze:
        logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
    end_points['Logits'] = logits
    if prediction_fn:
      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  最後這裡,他對Logits變數進行了刪除維度為1的過程。並且將名稱重新命名為SpatialSqueeze,一般如果不進行這一步就沒問題。所以我們如果出問題了,就對模型進行檢視,當然第二個方法是可行的。

 

2.3  檢視生成的pb檔案的輸入輸出節點

  檢視pb檔案的節點,只是為了驗證一下,當然也可以不檢視,直接去上面拿到結果即可,就是輸出節點的名稱。

def create_graph(out_pb_path):
    # 讀取並建立一個圖graph來存放訓練好的模型
    with tf.gfile.FastGFile(out_pb_path, 'rb') as f:
        # 使用tf.GraphDef() 定義一個空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

def check_pb_out_name(out_pb_path, result_file):
    create_graph(out_pb_path)
    tensor_name_list = [tensor.name for tensor in
                        tf.get_default_graph().as_graph_def().node]
    with open(result_file, 'w+') as f:
        for tensor_name in tensor_name_list:
            f.write(tensor_name+'\n')

  我們執行後,檢視對應的TXT檔案,可以看到,輸入輸出的節點和前面是對應的:

  這樣就解決了這個問題,最後使用pb模型進行預測即可。下面是這兩個查詢輸出節點的完整程式碼:

# _*_coding:utf-8_*_
from tensorflow.python import pywrap_tensorflow
import os
import tensorflow as tf

def check_out_pb_name(checkpoint_path):
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        res = reader.get_tensor(key)
        print('tensor_name: ', key)
        print('res.shape: %s'%[res.shape])

def create_graph(out_pb_path):
    # 讀取並建立一個圖graph來存放訓練好的模型
    with tf.gfile.FastGFile(out_pb_path, 'rb') as f:
        # 使用tf.GraphDef() 定義一個空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

def check_pb_out_name(out_pb_path, result_file):
    create_graph(out_pb_path)
    tensor_name_list = [tensor.name for tensor in
                        tf.get_default_graph().as_graph_def().node]
    with open(result_file, 'w+') as f:
        for tensor_name in tensor_name_list:
            f.write(tensor_name+'\n')



if __name__ == '__main__':
    # 輸入ckpt模型路徑
    checkpoint_path = 'modelsmobilenet/model.ckpt-100000'
    check_out_pb_name(checkpoint_path)

    # 輸出pb模型的路徑
    out_pb_path = 'modelmobilenet.pb'
    result_file = 'mobilenet_graph.txt'
    check_pb_out_name(out_pb_path, result_file)

  

三,使用pb模型預測

  下面是pb模型預測的程式碼:

def freeze_graph_test(pb_path, image_path):
    '''
    :param pb_path: pb檔案的路徑
    :param image_path: 測試圖片的路徑
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定義輸入的張量名稱,對應網路結構的輸入張量
            # input:0作為輸入影象,keep_prob:0作為dropout的引數,測試時值為1,is_training:0訓練引數
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

            # 定義輸出的張量名稱
            output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

            # 讀取測試圖片
            im = read_image(image_path, resize_height, resize_width, normalization=True)
            im = im[np.newaxis, :]
            # 測試讀出來的模型是否正確,注意這裡傳入的是輸出和輸入節點的tensor的名字,不是操作節點的名字
            # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
            out = sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                                          input_keep_prob_tensor: 1.0,
                                                          input_is_training_tensor: False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print("pre class_id:{}".format(sess.run(class_id)))

  

3.1  說明

1,與ckpt預測不同的是,pb檔案已經固化了網路模型結構,因此,即使不知道原訓練模型(train)的原始碼,我們也可以恢復網路圖,並進行預測。恢復模型非常簡單,只需要從讀取的序列化資料中匯入網路結構即可:

tf.import_graph_def(output_graph_def, name="")

2,但是必須知道原網路模型的輸入和輸出的節點名稱(當然了,傳遞資料時,是通過輸入輸出的張量來完成的)。由於InceptionV3模型的輸入有三個節點,因此這裡需要定義輸入的張量名稱,它對應的網路結構的輸入張量:

input_image_tensor = sess.graph.get_tensor_by_name("input:0")

input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")

input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

  以及輸出的張量名稱:

output_tensor_name = sess.graph.get_tensor_by_name(
                                        "InceptionV3/Logits/SpatialSqueeze:0")

3,預測時,需要 feed輸入資料

# 測試讀出來的模型是否正確
# 注意這裡傳入的是輸出和輸入節點的tensor的名字,不是操作節點的名字
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", 
                 feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                            input_keep_prob_tensor:1.0,
                                            input_is_training_tensor:False})

4,其他網路模型預測時,也可以通過修改輸入和輸出的張量的名稱。

(PS:注意張量的名稱,即為:節點名稱+ “:”+“id號”,如"InceptionV3/Logits/SpatialSqueeze:0")

 

  完整的CKPT轉換成PB格式和預測的程式碼如下:

# _*_coding:utf-8_*_
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import cv2

'''
checkpoint檔案是檢查點檔案,檔案儲存了一個目錄下所有模型檔案列表。
model.ckpt.data檔案儲存了TensorFlow程式中每一個變數的取值
model.ckpt.index檔案則儲存了TensorFlow程式中變數的索引
model.ckpt.meta檔案則儲存了TensorFlow計算圖的結構
'''


def freeze_graph(input_checkpoint, output_graph):
    '''
    指定輸出的節點名稱
    將模型檔案和權重檔案整合合併為一個檔案
    :param input_checkpoint:
    :param output_graph: PB模型儲存路徑
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)
    # 檢查目錄下的ckpt檔案狀態是否可以用
    # input_checkpoint = checkpoint.model_checkpoint_path  # 得ckpt檔案路徑

    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    # PS:注意節點名稱,應包含name_scope 和 variable_scope名稱空間,並用“/”隔開,
    output_node_names = 'MobilenetV1/Logits/SpatialSqueeze'
    # 首先通過下面函式恢復圖
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    # 然後通過下面函式獲得預設的圖
    graph = tf.get_default_graph()
    # 返回一個序列化的圖代表當前的圖
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        # 載入已經儲存的模型,恢復圖並得到資料
        saver.restore(sess, input_checkpoint)
        # 在儲存的時候,通過下面函式來指定需要固化的節點名稱
        output_graph_def = graph_util.convert_variables_to_constants(
            # 模型持久化,將變數值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等於:sess.graph_def
            # freeze模型的目的是接下來做預測,
            # 所以 output_node_names一般是網路模型最後一層輸出的節點名稱,或者說我們預測的目標
            output_node_names=output_node_names.split(',')  # 如果有多個輸出節點,以逗號隔開
        )

        with tf.gfile.GFile(output_graph, 'wb') as f:  # 儲存模型
            # 序列化輸出
            f.write(output_graph_def.SerializeToString())
        # # 得到當前圖有幾個操作節點
        print('%d ops in the final graph' % (len(output_graph_def.node)))

        # 這個可以得到各個節點的名稱,如果斷點除錯到輸出結果,看看模型的返回資料
        # 大概就可以猜出輸入輸出的節點名稱
        for op in graph.get_operations():
            print(op.name)
            # print(op.name, op.values())


def read_image(filename, resize_height, resize_width, normalization=False):
    '''
    讀取圖片資料,預設返回的是uint8,[0,255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param normalization:是否歸一化到[0.,1.0]
    :return: 返回的圖片資料
    '''

    bgr_image = cv2.imread(filename)
    if len(bgr_image.shape) == 2:  # 若是灰度圖則轉為三通道
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # 將BGR轉為RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height > 0 and resize_width > 0:
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # 不能寫成:rgb_image=rgb_image/255
        rgb_image = rgb_image / 255.0
    # show_image("src resize image",image)
    return rgb_image


def freeze_graph_test(pb_path, image_path):
    '''
    預測pb模型的程式碼
    :param pb_path: pb檔案的路徑
    :param image_path: 測試圖片的路徑
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            # 恢復模型,從讀取的序列化資料中匯入網路結構即可
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定義輸入的張量名稱,對應網路結構的輸入張量
            # input: 0 作為輸入影象,
            # keep_prob:0作為dropout的引數,測試時值為1,
            # is_training: 0 訓練引數
            input_image_tensor = sess.graph.get_tensor_by_name('input:0')
            input_keep_prob_tensor = sess.graph.get_tensor_by_name('keep_prob:0')
            input_is_training_tensor = sess.graph.get_tensor_by_name('is_training:0')

            # 定義輸出的張量名稱:注意為節點名稱 + “:”+id好
            name = 'MobilenetV1/Logits/SpatialSqueeze:0'
            output_tensor_name = sess.graph.get_tensor_by_name(name=name)

            # 讀取測試圖片
            im = read_image(image_path, resize_height, resize_width, normalization=True)
            im = im[np.newaxis, :]
            # 測試讀出來的模型是否正確,注意這裡傳入的時輸出和輸入節點的tensor的名字,不是操作節點的名字
            out = sess.run(output_tensor_name, feed_dict={
                input_image_tensor: im,
                input_keep_prob_tensor: 1.0,
                input_is_training_tensor: False
            })
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print('Pre class_id:{}'.format(sess.run(class_id)))


if __name__ == '__main__':
    # 輸入ckpt模型路徑
    input_checkpoint = 'modelsmobilenet/model.ckpt-100000'
    # 輸出pb模型的路徑
    out_pb__path = 'modelmobilenet.pb'
    # 指定圖片的高度,寬度
    resize_height, resize_width = 224, 224
    depth = 3

    # 呼叫freeze_graph將ckpt轉pb
    # freeze_graph(input_checkpoint, out_pb__path)

    # 測試pb模型
    image_path = '5.png'
    freeze_graph_test(pb_path=out_pb__path, image_path=image_path)

  結果如下:

out:[[ -6.41409     -7.542293    -4.79263     -0.8360114   -5.9790826
    4.5435553   -0.36825374  -6.4866605   -2.4342375   -0.77123785
   -3.8730755   -2.9347122   -1.2668624   -2.0682898   -4.8219028
   -4.0054555   -4.929347    -4.3350396   -1.3294952   -5.2482243
   -5.6148944   -0.5199025   -2.8043954   -7.536846    -8.050901
   -5.4447656   -6.8323407   -6.221056    -8.040736    -7.3237658
  -10.494858    -9.077686    -6.8210897  -10.038142    -9.5562935
   -3.884094    -4.31197     -7.0326185   -2.3761833   -9.571469
    1.0321844   -9.319367    -5.5040984   -4.881267    -6.99698
   -9.591501    -8.059127    -7.494555   -10.593867    -6.862433
   -4.373736  ]]
Pre class_id:[5]

  我將測試圖片命名為5,就是與結果相對應,結果一致。表明使用pb預測出來了,並且預測正確。

  這裡解釋一下,我是使用MobileNetV1模型進行訓練一個51個分類的資料,而拿到的第6個類的資料進行測試(我的標籤是從0開始的),這裡測試正確。

 

 

   此文是自己的學習筆記總結,學習於《TensorFlow深度學習框架》,俗話說,好記性不如爛筆頭,寫寫總是好的,所以若侵權,請聯絡我,謝謝。

  其實網上有很多ckpt轉pb的文章,大多數來自下面的部落格,我這裡