1. 程式人生 > >如何檢視Tensoflow模型中已儲存的引數

如何檢視Tensoflow模型中已儲存的引數

1.儲存和讀取

1.1 儲存

import tensorflow as tf


aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(aa))
    # Step 1  儲存
    saver.save(sess,'./ttt')
    
>>
[[ 0.8604646 0.45935377 -0.24135743 -2.2841513 ] [-0.20688622 0.60574555 -0.26031223 -0.441991 ] [-0.22254886 1.4805079 -1.7360271 1.1423918 ]]

這兒我們定義了一個name=var的變數(隨便說一句aa這類名稱是我們寫程式時用以區分各個變數之間的依據,換句話說是給我們自己看的;而var這個名字是tensorflow計算圖上用來區分各個變數和操作的依據),並且將其進行了儲存。

1.2 讀取

說到讀取,就有兩個方面了:第一,知道引數的名字(上面的var)時之間讀取該變數;第二,不知道引數的名稱時可以先打出所有變數,然後找你所要變數對應的名稱再按名讀取就行。

#----------------------直接按名讀取---------------------------

import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
    print(sess.run(tf.get_default_graph().get_tensor_by_name('var:0')))
    
>>
[[ 0.8604646 0.45935377 -0.24135743 -2.2841513 ] [-0.20688622 0.60574555 -0.26031223 -0.441991 ] [-0.22254886 1.4805079 -1.7360271 1.1423918 ]]

#----------------------檢視所有變數名---------------------------
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:

    saver.restore(sess, './ttt')
    var_list = [v.name for v in tf.global_variables()]
    print(var_list)
    print(sess.run(var_list))
    
['var:0']
[array([[ 0.8604646 ,  0.45935377, -0.24135743, -2.2841513 ],
       [-0.20688622,  0.60574555, -0.26031223, -0.441991  ],
       [-0.22254886,  1.4805079 , -1.7360271 ,  1.1423918 ]],
      dtype=float32)]

可以看到讀取變數後的輸出值和儲存時的一樣。

2.哪些變數能夠儲存

其實saver.save()在儲存引數的時候是有選擇的(我說的選擇不是通過save()引數裡面控制的引數),看例子:

aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)

i = 10

saver = tf.train.Saver()
with tf.Session() as sess:
    # Step 1  儲存
    sess.run(tf.global_variables_initializer())
    saver.save(sess,'./ttt')

這兒我們一共定義了6個引數,其中有三個tensor變數(aa,dd,ee)和兩個tensor常量(bb,cc),和一個普通變數,我們來看一下哪些引數儲存成功:

aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)
i = 10

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './ttt')
    var_list = [v.name for v in tf.global_variables()]
    # print(sess.run(list_before_train))
    print(var_list)

>>

['aa:0', 'cc_1:0', 'ee:0']

我們可以看到,這兒只有3個變數被儲存成功,aa,ee,cc_1。明顯,aa指得就是第1行程式碼定義得變數,ee指得就是第5行程式碼,那麼cc_1指得是第3行還是第4行呢? 指得是第4行,這也印證tensorflow內部是通過name='var'這個引數來區分的。

由此我們可以得出:saver.save()只儲存tensor變數,也就是tf.Variable()定義的變數,其它量包括tensor常量都是不被儲存的。

3.網路模型的引數也能這樣來儲存麼?

答案是:能!

這裡以一個rnn cell按時間維度展開為例:

#--------------------------------------------儲存--------------------------------------

import tensorflow as tf
import numpy as np

output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Step 1  儲存
    saver.save(sess,'./ttt')

#-------------------------------------------讀取--------------------------------------


import tensorflow as tf
import numpy as np

output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x1
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]],  # x2
              [[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]])  # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, './ttt')
    var_list = [v.name for v in tf.global_variables()]
    print(var_list)
    
>>
['rnn/basic_rnn_cell/kernel:0', 'rnn/basic_rnn_cell/bias:0']

既然都儲存了,那為什麼這兒只有兩個變數呢?那是因為tensorflow內部在計算時為了方便或是更快,把所有的weight和bias都疊在一起了,具體參見此處!

另外說明一下:

在網上看到很多人提問LSTM訓練好的模型“儲存不了”。為什麼會覺得儲存不了呢? 因為在當訓練到某個時候loss已經很低了,當stop後再次載入最新幾個模型時都發現loss急劇升高,因此就會決定是因為模型的引數沒有儲存成功而導致的,因為本人在這兩天也出現了這個問題。於是網上各種搜查LSTM模型儲存的方法,試了一大堆依舊無效,後來終於發現是由於同一個函式在不同平臺(windwo,linux)上的處理結果居然不一樣,導致預處理後的訓練集一直在變而導致的!

另外,你還可以通過在每次儲存LSTM模型時,打印出其中某個引數的具體值,然後手動stop;當你再次載入模型時,立馬輸出同一個變數,對比一下是否相同,如果相同則說明儲存成功。依照我自己的實驗來看,兩者是相同的。


print('------儲存時的值----->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5, :4])
                      

print('載入時的值----------->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5,:4])


#   同一個變數,相同部分的值