1. 程式人生 > >在tensorflow中儲存模型引數

在tensorflow中儲存模型引數

想要儲存訓練之後得到的神經網路引數,一般有兩種辦法。

第一種,可以將tensor物件轉換為numpy陣列進行儲存。

即,

numpy.savetxt('weight.txt', weight.eval())

第二種,是利用tensorflow自帶的Saver物件。

import tensorflow as tf ##################################################3 w1 = tf.Variable(tf.constant(1.0), name='w1') w2 = tf.Variable(tf.constant(2.0), name='w2'
) tf.add_to_collection('vars', w1) tf.add_to_collection('vars', w2) saver = tf.train.Saver()
with tf.Session() as sess:     sess.run(tf.global_variables_initializer())     w1 = tf.add(w1, w2)     saver.save(sess, './my-model.ckpt')

上面的程式碼中,建立了容器vars。它收集了tensor變數w1和w2。之後,tensorflow將這一容器儲存。

在session中執行,就能將資料儲存到tensorflow建立的幾個檔案中。

上面的程式碼執行結束後,當前目錄下出現四個檔案:

my-model.ckpt.meta

my-model.ckpt.data-*

my-model.ckpt.index

checkpoint

利用這四個檔案就能恢復出 w1和w2這兩個變數。

with tf.Session() as sess:     new_saver = tf.train.import_meta_graph('my-model.ckpt.meta')     new_saver.restore(sess, tf.train.latest_checkpoint('./'))     all_vars = tf.get_collection('vars'
)     print(all_vars)     for v in all_vars:         print(v)         print(v.name)         v_ = v.eval() # sess.run(v)         print(v_)

執行結果為:


[<tf.Tensor 'w1:0' shape=() dtype=float32_ref><tf.Tensor 'w2:0' shape=() dtype=float32_ref>] Tensor("w1:0"shape=(), dtype=float32_ref) w1:0 1.0 Tensor("w2:0"shape=(), dtype=float32_ref) w2:0 2.0