1. 程式人生 > >一份快速完整的Tensorflow模型儲存和恢復教程(譯)

一份快速完整的Tensorflow模型儲存和恢復教程(譯)

在本教程中,我將介紹:
- tensorflow模型是什麼樣子的?
- 如何儲存一個Tensorflow模型?
- 如何恢復一個Tensorflow模型用於預測/遷移學習?
- 如何匯入預訓練的模型進行微調和修改?

什麼是Tensorflow模型?

當你訓練好一個神經網路後,你會想儲存好你的模型便於以後使用並且用於生產。因此,什麼是Tensorflow模型?Tensorflow模型主要包含網路設計(或者網路圖)和訓練好的網路引數的值。所以Tensorflow模型有兩個主要的檔案:

a) Meta圖:
Meta圖是一個協議緩衝區(protocol buffer),它儲存了完整的Tensorflow圖;比如所有的變數、運算、集合等。這個檔案的副檔名是.meta

b) Checkpoint 檔案
這是一個二進位制檔案,它儲存了權重、偏置項、梯度以及其他所有的變數的取值,副檔名為.ckpt。但是, 從0.11版本開始,Tensorflow對改檔案做了點修改,checkpoint檔案不再是單個.ckpt檔案,而是如下兩個檔案:

mymodel.data-00000-of-00001
mymodel.index

其中, .data檔案包含了我們的訓練變數。除此之外,還有一個叫checkpoint的檔案,它保留了最新的checkpoint檔案的記錄。

總結一下,對於0.10之後的版本,tensorflow模型包含以下檔案:

model files
但對於0.11之前的版本,只包含三個檔案:

inception_v1.meta
inception_v1.ckpt
checkpoin

現在我們已經知道Tensorflow模型是什麼樣子的,讓我們繼續學習如何儲存模型。

儲存Tensorflow模型

假如你正在訓練一個用於影象分類的卷積神經網路(training a convolutional neural network for image classification)。通常你會先觀察損失和準確率,一旦發現網路收斂,就可以手動停止訓練過程或者直接訓練固定迭代次數。當訓練完成後,我們想要儲存所有的變數和網路圖便於以後使用。因此在Tensorflow中, 為了儲存網路圖和所有引數的值,我們應該建立tf.train.Saver()這個類的一個物件。

saver = tf.train.Saver()

記住Tensorflow變數只有在會話(session)中才能啟用。因此,你需要在會話中呼叫你剛建立的物件的儲存方法。

saver.save(sess, "my-test-model")

這裡,sess是一個session物件,“my-test-model”是你的模型名字。讓我們看一個完整的例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

如果我們要在1000次迭代後儲存模型,我們應該在呼叫儲存方法時傳入步數計數:

saver.save(sess, "my_test_model", global_step=1000)

這會在模型名稱後加一個“-1000”並且會建立如下檔案:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint

假設在訓練過程中,我們要每1000次迭代儲存我們的模型,因此.meta檔案會在第一次(1000次迭代)時建立,我們並不需要之後每1000次迭代都儲存一遍這個檔案(我們在2000,3000…迭代時都不需要儲存這個檔案,因為這個檔案始終不變)。我們只需要儲存這個模型供以後使用,因為模型圖不會變化。所以,當我們不想重寫meta圖的時候,我們這樣寫:

saver.save(sess, "my-model", global_step=step, write_meta_graph=False)

如果你只想保留4個最新的模型並且在訓練過程中每過2小時儲存一次模型,你可以使用max_to_keep和keep_checkpoint_every_n_hours,就像這樣:

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

注意,如果我們在tf.train.Saver()中不指定任何東西,它將儲存所有的變數。要是我們不想儲存所有的變數而只是一部分變數。我們可以指定我們想要儲存的變數/集合。當建立tf.train.Saver()物件的時候,我們給它傳遞一個我們想要儲存的變數的字典列表。我們來看一個例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)

當需要的時候,這個程式碼可以用來儲存Tensorflow圖中的指定部分。

匯入預訓練模型

如果你想要用其他人預訓練的模型進行微調,需要做兩件事:

a) 建立網路
你可以寫python程式碼來手動建立和原來一樣的模型。但是,想想看,我們已經將原始網路儲存在了.meta檔案中,可以用tf.train.import()函式來重建網路:

saver = tf.train.import_meta_graph("my_test_model-1000.meta")

記住,import_meta_graph函式將只將定義在.meta檔案中的網路新增到當前的圖上。因此,它雖然幫你建立了額圖/網路,但我們還是需要匯入我們在這個圖上訓練好的模型的引數。

b) 匯入引數
我們可以呼叫由tf.train.Saver()建立的物件saver中的restore方法來恢復網路中的引數。

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))

這樣,張量的值(如w1和w2)就被恢復並且可以訪問了:

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('my-model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.

現在你已經理解了如何儲存和匯入Tensorflow模型。在下一節,我會介紹一個實際應用即匯入任何預訓練好的模型。

使用恢復的模型

現在你已經理解如何儲存和恢復Tensorflow模型,我們來寫一個實際的示例來恢復任何預訓練的模型並用它來預測、微調或者進一步訓練。無論你什麼時候用Tensorflow,你都會定義一個網路,它有一些樣本(訓練資料)和超引數(如學習率、迭代次數等)。通常用一個佔位符(placeholder)來將所有的訓練資料和超引數輸入給網路。下面我們用佔位符建立一個小型網路並儲存它。注意,當網路被儲存的時候,佔位符中的值並沒有被儲存。

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

當我們想要恢復這個網路的時候,我們不僅需要恢復圖和權重,還需要準備一個新的feed_dict來將訓練資料輸入到網路中。我們可以通過graph.get_tensor_by_name方法來引用這些儲存的運算和佔位符變數。

#How to access saved variable/Tensor/placeholders 
w1 = graph.get_tensor_by_name("w1:0")

## How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

如果我們只是想用不同的資料執行相同的網路,你可以方便地用feed_dict將新的資料送到網路中。

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 
#using new values of w1 and w2 and saved value of b1. 

要是你想在原來的計算圖中通過新增更多的層來增加更多的運算並且訓練。當然也可以實現,如下:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.

但是,我們能夠只恢復原來圖中的一部分然後新增一些其它層來微調嗎?當然可以,只要通過graph.get_tensor_by_name()方法來獲取原網路的部分計算圖並在上面繼續建立新計算圖。這裡給出了一個實際的例子。我們用meta圖匯入了一個預訓練的vgg網路,然後將最後一層的輸出個數改成2用於微調新的資料。

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

num_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()

希望本文能夠讓你清楚地理解Tensorflow是如何被儲存和微調的。請在評論區自由分享你的問題或者疑問。