1. 程式人生 > >TensorFlow儲存、載入模型引數 | 原理描述及踩坑經驗總結

TensorFlow儲存、載入模型引數 | 原理描述及踩坑經驗總結

寫在前面

我之前使用的LSTM計算單元是根據其前向傳播的計算公式手動實現的,這兩天想要和TensorFlow自帶的tf.nn.rnn_cell.BasicLSTMCell()比較一下,看看哪個訓練速度快一些。在使用tf.nn.rnn_cell.BasicLSTMCell()進行建模的時候,遇到了模型儲存、載入的問題。

查找了一些博主的經驗,再加上自己摸索,在這裡做個筆記,總結經驗。其中關鍵要素有以下3點:

1.需要儲存哪些變數(tensor),就要給哪些變數取名字(即name='XXXXX')。

2.將tf.train.Saver()與需要儲存的變數(tensor)定義在一個函式裡,否則儲存會出錯。

3.載入模型的時候,先載入圖,再載入變數(tensor)。

下面通過例項進行描述。

 

模型儲存

tf.train.Saver()可以自動儲存變數和計算圖。

儲存前注意!!!需要對要儲存的變數命名,即屬性中的name=XXX

下面是使用tf.nn.rnn_cell.BasicLSTMCell()自建的一個LSTM_Cel

 1 class LSTM_Cell(object):
 2 # train_data 格式示例,batch_size*num_steps*input_dim   批大小*時間視窗長度*單時間節點輸入維度
 3 # train_label格式示例,batch_size*1  # TODO 該模型緊輸出一維結果。
 4 # input_dim 格式 int, 輸入資料在單時間節點上的維度
 5 # num_nodes 神經元數目/維度
 6     def __init__(self, train_data, train_label, input_dim, batch_size=10, num_nodes=64):
 7         tf.reset_default_graph()
 8         self.num_nodes = num_nodes
 9         self.input_dim = input_dim
10         self.train_data = train_data
11         self.train_label = train_label
12         self.batch_size = batch_size
13         gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.666)
14         self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
15 
16     def loss_func(self,lr=0.001):
17         self.w = tf.Variable(tf.truncated_normal([self.num_nodes, 1], -0.1, 0.1),name='w')   # 1 是指輸出維度,這裡預測一個值,因此維度是1
18         self.b = tf.Variable(tf.zeros([1]),name='b')
19         self.batch_in = tf.placeholder(tf.float32, [None, self.train_data.shape[1], self.input_dim],name='batch_in')
20         self.batch_out = tf.placeholder(tf.float32, [None, 1],name='batch_out')
21         lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_nodes,forget_bias=1.0,state_is_tuple=True)
22         # init_state = lstm_cell.zero_state(self.batch_in[0],dtype=tf.float32)
23         output, final_state = tf.nn.dynamic_rnn(lstm_cell, self.batch_in, time_major=False, dtype=tf.float32)  # initial_state=init_state,
24         self.y_pre = tf.nn.sigmoid(tf.matmul(final_state[1], self.w) + self.b,name="y_pre")
25         self.mse = tf.sqrt(tf.reduce_mean(tf.square(self.y_pre-self.batch_out)),name='mse')
26         self.cross_entropy = -tf.reduce_mean(self.batch_out * tf.log(self.y_pre),name='cross_entropy')
27         self.train_op = tf.train.GradientDescentOptimizer(lr).minimize(self.mse)
28         self.saver = tf.train.Saver()
29 
30     def train_model(self,savepath,epochs=1000):
31         self.sess.run(tf.global_variables_initializer())
32         for i in range(epochs):
33             for j in range(int(len(self.train_data)/self.batch_size)):
34                 batch_i = self.train_data[j*self.batch_size:(j+1)*self.batch_size]
35                 batch_o = self.train_label[j*self.batch_size:(j+1)*self.batch_size]
36                 self.sess.run(self.train_op, feed_dict={self.batch_in:batch_i, \
37                         self.batch_out:batch_o.reshape(self.batch_size,1)})
38             if (i+1)%200==0:
39                 print('epoch:%d'%(i+1),self.sess.run(self.mse,feed_dict={self.batch_in:batch_i, \
40                     self.batch_out:batch_o.reshape(self.batch_size,1)}))
41         save_path = self.saver.save(self.sess, savepath)
42         print("模型保存於: ", save_path)

在LSTM_Cell類中,建構函式定義了一些固定引數以及TensorFlow會話(tf.Session()),而我們所要儲存的變數(tensor)都在loss_func()函式中定義。包括:

①最後一個全連線層的w和b;

②輸入、輸出變數的佔位符batch_in,batch_out;

③LSTM單元的計算過程;

④計算最終計算結果y_pre,均方根誤差mse,交叉熵計算結果cross_entropy,使用隨機梯度下降的訓練步驟train_op;

⑤儲存器tf.train.Saver()。

在本例中,只有變數’w’, ’b’, ’batch_in’, ’batch_out’, ’y_pre’, ’mse’, ’cross_entropy’在屬性中有過命名,會被儲存下來。

這裡,tf.train.Saver()只能儲存本函式(即loss_func)中定義的變數(tensor)。

train_model()函式實現訓練過程,並呼叫self.saver.save(self.sess, savepath)來對模型及命名了的變數(tensor)進行儲存。

下面是呼叫LSTM_Cell類進行訓練並儲存模型的程式碼:

1 # 初始化LSTM類
2 lstm_obj = LSTM_Cell(sample_input,sample_output,input_dim=1,batch_size=_batch_size,num_nodes=hidden_size)
3 lstm_obj.loss_func(lr)  # 構建計算圖
4 # TODO 訓練
5 lstm_obj.train_model(savepath=saved_path,epochs=epochs)

其中,儲存路徑為

1 saved_path = "./standard_LSTM/models/Basic_LSTM_TF_models/59model.ckpt"

最終得到的儲存結果為下方4個檔案(暫時無視兩個png圖片)

 

 

 .meta文件是計算圖儲存的位置,.data是引數資料,後面的00000-of-00001是模型的版本號。

 

模型載入

載入困擾了我很久,後面經過摸索才知道有兩個關鍵部分,一個是計算圖的載入,一個是變數的載入,兩者缺一不可。

LSTM_Cell類中,載入函式(load_model)定義如下。

 1     def load_model(self,savepath):
 2         len_last = len(savepath.split('/')[-1])
 3         self.saver = tf.train.import_meta_graph(savepath+'.meta')
 4         self.saver.restore(self.sess,tf.train.latest_checkpoint(savepath[:-len_last]))  # 載入最後一個模型
 5         self.graph = tf.get_default_graph()
 6         tensor_name_list = [tensor.name for tensor in self.graph.as_graph_def().node]
 7         self.w = self.graph.get_tensor_by_name('w:0')
 8         self.b = self.graph.get_tensor_by_name('b:0')
 9         self.batch_in = self.graph.get_tensor_by_name("batch_in:0")  
10         self.batch_out = self.graph.get_tensor_by_name("batch_out:0")
11         self.y_pre = self.graph.get_tensor_by_name('y_pre:0')
12         self.mse = self.graph.get_tensor_by_name('mse:0') 
13         self.cross_entropy = self.graph.get_tensor_by_name('cross_entropy:0')

首先定義一個self.saver,來輔助載入圖及變數。

第一步載入圖,即tf.train.import_meta_graph(savepath+'.meta'),就是載入上圖中的 59model.ckpt.meta

saver.restore()函式將模型引數進行載入,savepath[:-len_last]是指儲存模型的資料夾路徑,即"./standard_LSTM/models/Basic_LSTM_TF_models/" ,將模型載入到預設的計算圖中(default_graph)。

此時,各變數(即tensor)已經在計算圖中了,但要正常呼叫,還需要從圖中取出並將其設定成變數。

具體方法是先取得預設的計算圖self.graph,再通過get_tensor_by_name()方法將tensor例項化,每個tensor的名稱與模型儲存時name=”XXX”的名稱相同,並且後方需要加上:<index>,不重名的情況下這個index一般是0。

也有博主說tensor的名稱可以在tensor_name_list中檢視到,但我打印出來後發現這個list太長,不大實用。

這時候就載入模型完畢了,可以呼叫self.sess對self.y_pre、self.mse進行計算。

示例計算如下:

1     def predict_next_one(self,batch_i):  # batch_i長度 為樣本時間序列長度
2         temp = self.sess.run(self.y_pre,feed_dict={self.batch_in:batch_i.reshape(1,len(batch_i),1)})
3         return temp[0][0]

外部的呼叫方法如下,(建構函式後就不用使用loss_func構建計算過程了,直接載入模型就行。)

1 # 初始化LSTM類
2 lstm_obj = LSTM_Cell(sample_input,sample_output,input_dim=1,batch_size=_batch_size,num_nodes=hidden_size)
3 # TODO 載入模型
4 lstm_obj.load_model(savepath=saved_path)

下面兩個圖是訓練完後直接預測以及載入模型再預測的結果,可以看出模型載入後,計算結果與之前一致。

&n