1. 程式人生 > >tensorflow 模型儲存與載入

tensorflow 模型儲存與載入

在訓練一個神經網路模型後,你會儲存這個模型未來使用或部署到產品中。所以,什麼是TF模型?TF模型基本包含網路設計或圖,與訓練得到的網路引數和變數。因此,TF模型具有兩個主要檔案:
a)meta圖
這是一個擬定的快取,包含了這個TF圖完整資訊;如所有變數等等。檔案以.meta結束。
b)檢查點檔案:
這個檔案是一個二進位制檔案,包含所有權重、偏移、梯度和所有其它儲存的變數的值。這個檔案以.ckpy結束。然而,TF已經在0.11版本後不再以這個形式了。轉而檔案包含如下檔案 :
mymodel.data-00000-of-00001
mymodel.index
.data檔案包含訓練變數。
除此之外 ,TF還包含一個名為“checkpoint”的檔案 ,儲存最後檢查點的檔案。
所以,綜上,TF模型包含如下檔案 :

  • my_test_model.data-00000-of-00001
  • my_test_model.index
  • my_test_model.meta
  • checkpoint**

2儲存一個TF模型
saver = tf.train.Saver()
注意,你需要在一個session中儲存這個模型
Python
1saver.save(sess, ‘my-model-name’)
完整的例子為:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5
]), name='w2') saver = tf.train.Saver() sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, 'my_test_model')

如果是在TF模型迭代1000步後儲存這個模型,可以指定步數
saver.save(sess, ‘my_test_model’,global_step=1000)

3.載入一個預訓練的模型
a)建立網路
使用tf.train.import()函式載入以前儲存的網路。
saver = tf.train.import_meta_graph(‘my-model-1000.meta’)
注意,import_meta_graph將儲存在.meta檔案中的圖新增到當前的圖中。所以,建立了一個圖/網路,但是我們使用需要載入訓練的引數到這個圖中。

b)載入引數

'''restore tensor from model'''
w_out= self.graph.get_tensor_by_name('W:0')
b_out = self.graph.get_tensor_by_name('b:0')
_input = self.graph.get_tensor_by_name('x:0')
_out = self.graph.get_tensor_by_name('y:0')
y_pre_cls = self.graph.get_tensor_by_name('output:0')

注意問題1:
初始儲存位置如果為e:,則這個位置被儲存在checkpoint中
修改後:
model_checkpoint_path: “E:\tmp\newModel\crack_capcha.model-8100”
all_model_checkpoint_paths: “E:\tmp\newModel\crack_capcha.model-8100”

這個過程形象的描述
Technically, this is all you need to know to create a class-based neural network that defines the fit(X, Y) and predict(X) functions.

見stackoverFlow解釋
In( and After) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph according tohttps://www.tensorflow.org/programmers_guide/meta_graph
save model:

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')

**# save method will call export_meta_graph implicitly.
you will get saved graph files:my-model.meta**
restore model:

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

一個完整的例子:
self.session = tf.Session(graph=self.graph)

with self.graph.as_default():####預設圖與自定義圖的關係
    ckpt = tf.train.get_checkpoint_state(self.savefile)
       if ckpt and ckpt.model_checkpoint_path:
           print(''.join([ckpt.model_checkpoint_path,'.meta']))
           self.saver = tf.train.import_meta_graph(''.join([ckpt.model_checkpoint_path,'.meta']))
           self.saver.restore(self.session,ckpt.model_checkpoint_path)
       #print all variable
       for op in self.graph.get_operations():
       print(op.name, " " ,op.type)
       #返回模型中的tensor
       layers = [op.name for op in self.graph.get_operations() if op.type=='Conv2D' and 'import/' in op.name]
       layers = [op.name for op in self.graph.get_operations()]
       feature_nums = [int(self.graph.get_tensor_by_name(name+':0').get_shape()[-1]) for name in layers]
       for feature in feature_nums:
            print(feature)

     '''restore tensor from model'''
     w_out = self.graph.get_tensor_by_name('W:0')
     b_out = self.graph.get_tensor_by_name('b:0')
     _input = self.graph.get_tensor_by_name('x:0')
     _out = self.graph.get_tensor_by_name('y:0')
     y_pre_cls = self.graph.get_tensor_by_name('output:0')
     #self.session.run(tf.global_variables_initializer())   ####非常重要,不能新增這一句
        pred = self.session.run(y_pre_cls,feed_dict={_input:_X})
        return pred

中間有許多坑,但是成功的載入執行後,對模型的瞭解也加深了