1. 程式人生 > >tensorflow(三) 模型儲存

tensorflow(三) 模型儲存

tensorflow最簡單的儲存與載入模型的方法是Saver物件(存放在tensorflow.train)。構造器給graph所有的變數,或者定義在列表中的變數,新增save和restore的操作,分別為儲存和載入。變數儲存在二進位制的檔案中,主要包含的是從變數名到tensor值的對映關係。

儲存變數
通過下面的一段程式碼穿件Saver物件來管理模型中的變數(預設情況下是所有的變數,也可以自行選擇)。

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
v2 = tf.Variable
(tf.random_normal([2,3]), name="v2") init_op = tf.initialize_all_variables() saver = tf.train.Saver() with tf.Session() as sess: sess.run(init_op) saver_path = saver.save(sess, "/home/yang/data/model.ckpt") print "Model saved in file: ", saver_path

恢復變數
用同一個Saver物件來恢復變數,注意,當你從檔案恢復變數是,不需要對它進行初始化,否則會報錯。

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
v2 = tf.Variable(tf.random_normal([2,3]), name="v2")
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "/home/yang/data/model.ckpt")
    print "Model restored"