1. 程式人生 > >《TensorFlow:實戰Google深度學習框架》——5.4 模型持久化(模型儲存、模型載入)

《TensorFlow:實戰Google深度學習框架》——5.4 模型持久化(模型儲存、模型載入)

目錄

1、持久化程式碼實現

2、載入儲存的TensorFlow模型

3、載入部分變數

4、載入變數時重新命名


1、持久化程式碼實現

TensorFlow提供了一個非常簡單的API來儲存和還原一個神經網路模型。這個API就是tf.train.Saver類。一下程式碼給出了儲存TensorFlow計算圖的方法。

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf

# 儲存計算兩個變數和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # saver.sabe函式儲存到“Saved_model/model.ckpt”
    saver.save(sess, "Saved_model/model.ckpt")

解析:

  • 在這段程式碼中,通過saver.save 函式將TensorFlow模型儲存到了“Saved_model/model.ckpt”檔案中。TensorFlow模型一般會存在後綴為.ckpt的檔案中 。
  • 雖然以上程式只指定了 一個檔案路徑,但是在這個檔案目錄下會出現三個檔案:
  1. 第一個檔案為model.ckpt.meta,它儲存了 TensorFlow計算圖的結構
  2. 第二個檔案為model.ckpt,這個檔案中儲存了TensorFlow 程式中每一個變數的取值。
  3. 第三個檔案為checkpoint檔案,這個檔案中儲存了一個目錄下所有的模型檔案列表

2、載入儲存的TensorFlow模型

以下程式碼中給出了載入這個已經儲存的TensorFlow模型的方法 

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf

# 儲存計算兩個變數和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

saver = tf.train.Saver()

# 載入儲存的模型,載入全部模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

上述程式碼輸出為:

解析:

這段載入模型的程式碼基本上和儲存模型的程式碼是一樣的。在載入模型的程式中也是先定義了TensorFlow計算圖上的所有運算,並聲明瞭 一個tf.train.Saver類。兩段程式碼唯一不同的是,在載入模型的程式碼中沒有執行變數的初始化過程,而是將變數的值通過己經儲存的模型載入進來。


如果不希望重複定義圖上的運算,也可以直接載入已經持久化的圖。一下程式碼給出一個樣例:

import tensorflow as tf

# 載入持久化的圖
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")

with tf.Session() as sess:
    saver.restore(sess,"Saved_model/model.ckpt")
    # 通過張量的名稱來獲取張量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

這段程式碼與上述程式碼達到的效果相同。是兩種方式載入模型。


3、載入部分變數

為了儲存或者載入部分變數,在宣告 tf.train.Saver 類時可以提供一個列表來指定需要儲存或者載入的變數。比如在載入模型的程式碼中使用saver = tf.train. Saver([v1])命令來構建tf.train.Saver類,那麼只有變數v1會被載入進來 。如果執行修改後只加載了v1的程式碼會得到變革未初始化的錯誤:

tensorflow.python.framework.errors.FailedPreconditionError:Attempting touse uninitialized value v2

4、載入變數時重新命名

tf.train.Saver類也支援在儲存或者載入時給變數重新命名。下面給出了一個簡單的樣例程式說明變數重新命名是如何被使用的。

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf

# tf.reset_default_graph()

# 宣告變數
V1 = tf.Variable(tf.constant(1.0, shape=[1]), name="a1")
V2 = tf.Variable(tf.constant(2.0, shape=[1]), name="a2")
# result = V1 + V2

# 這裡要注意,checkpoint中的變數名的問題,不然就會出現問題
saver = tf.train.Saver({"Variable": V1, "Variable_1": V2})

# 載入儲存的模型,載入全部模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(V1+V2))

上述關於檢視checkpoint檔案中的變數名的問題,請參考博文TensorFlow中檢視checkpoint檔案中的變數名和對應值