1. 程式人生 > >Tensorflow儲存和讀取模型

Tensorflow儲存和讀取模型

1.概述

將深度學習應用到工業領域實時處理資料時,我們需要訓練好的模型實時計算,那就需要儲存和讀取模型,tensorflow目前提供了這方面的初步工作。因為tensorflow只能儲存變數而不是儲存整個網路,所以在提取模型時,我們還需要重新第一網路結構。

2.程式碼演示

(1)儲存

import tensorflow as tf
import numpy as np
#儲存時dtype型別要一致,一般使用float32,另外要定義變數名
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1
,2,3]], dtype=tf.float32, name='biases') # 初始化所有變數 init = tf.initialize_all_variables() # 構建儲存模型 saver = tf.train.Saver() #啟動 with tf.Session() as sess: sess.run(init) #定義儲存路徑 save_path = saver.save(sess, "/Users/chunsoft/Desktop/savemodel/save_test.ckpt") print("Save to path: ", save_path)

(2)讀取

import tensorflow as tf
import numpy as np
# 重新定義相同的變數的dtype和shape
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# 不需要初始化

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "/Users/chunsoft/Desktop/savemodel/save_test.ckpt"
) print("weights:", sess.run(W)) print("biases:", sess.run(b))

讀取和儲存還是很方便的,期待tensorflow版本更新後,能夠儲存整個網路。