
[Deep Learning] Saving and Restoring TensorFlow Models and Parameters

In most cases we save a trained network so that it can be used later, or so that its parameters can be fine-tuned starting from the current network.

Saving variables

# Written for the TensorFlow 1.x API (tf.Session, tf.train.Saver)
import tensorflow as tf

# The file path to save the data
save_file = './model.ckpt'

# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))

# Class used to save and/or restore Tensor Variables
saver = tf.train.Saver()

with tf.Session() as sess:
    # Initialize all the Variables
    sess.run(tf.global_variables_initializer())

    # Show the values of weights and bias
    print('Weights:')
    print(sess.run(weights))
    print('Bias:')
    print(sess.run(bias))

    # Save the model
    saver.save(sess, save_file)

Output (the initial values are random, so yours will differ):

Weights:
[[-0.97990924  1.03016174  0.74119264]
 [-0.82581609 -0.07361362 -0.86653847]]
Bias:
[ 1.62978125 -0.37812829  0.64723819]
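To see what saver.save() actually writes, one can inspect the checkpoint afterwards. The sketch below is a variant of the saving code above under a few assumptions that are not in the original: it runs the TF 1.x calls through tf.compat.v1 (so it also works on TensorFlow 2.x), gives the variables explicit names 'weights' and 'bias' (the original leaves them unnamed, so they get default names such as Variable and Variable_1), and writes to a temporary directory instead of ./model.ckpt. tf.train.list_variables then reports the name and shape of every stored tensor.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 1.x graph API via tf.compat.v1, so this also runs on TF 2.x.
tf.compat.v1.disable_eager_execution()

# Assumption: a temporary directory instead of the post's './model.ckpt'.
save_dir = tempfile.mkdtemp()
save_file = os.path.join(save_dir, 'model.ckpt')

# Explicit names (an addition to the original code) become the checkpoint keys.
weights = tf.compat.v1.Variable(tf.random.truncated_normal([2, 3]), name='weights')
bias = tf.compat.v1.Variable(tf.random.truncated_normal([3]), name='bias')

saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    saver.save(sess, save_file)

# Files written for the 'model.ckpt' prefix (index, data shard, meta graph, ...).
print(sorted(os.listdir(save_dir)))

# (name, shape) pair for every tensor stored in the checkpoint.
stored = dict(tf.train.list_variables(save_file))
print(stored)
```

The checkpoint is keyed by variable name, not by Python variable or creation order, which is what restoring relies on.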

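Restoring variables

To load the saved values, rebuild the same variables in a new graph and call saver.restore(). The variables do not need to be initialized first, because restore() assigns the saved values to them.
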
# Remove the previous weights and bias
tf.reset_default_graph()

# Two Variables: weights and bias
# (created in the same order as before, so they get the same default
# names and match the entries stored in the checkpoint)
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))

# Class used to save and/or restore Tensor Variables
saver = tf.train.Saver()

with tf.Session() as sess:
    # Load the weights and bias (no initializer needed)
    saver.restore(sess, save_file)

    # Show the values of weights and bias
    print('Weights:')
    print(sess.run(weights))
    print('Bias:')
    print(sess.run(bias))
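Saver matches graph variables to checkpoint entries by name. With the unnamed variables used above, the names depend on creation order, so creating bias before weights would make the names no longer line up. A minimal round-trip sketch, assuming tf.compat.v1 and explicit variable names (both additions to the original), rebuilds the variables in the opposite order and checks that the restored values equal the saved ones.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TF 1.x graph API via tf.compat.v1, so this also runs on TF 2.x.
tf.compat.v1.disable_eager_execution()

save_file = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Save: create, initialize, remember the values, write the checkpoint.
weights = tf.compat.v1.Variable(tf.random.truncated_normal([2, 3]), name='weights')
bias = tf.compat.v1.Variable(tf.random.truncated_normal([3]), name='bias')
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    saved_w, saved_b = sess.run([weights, bias])
    saver.save(sess, save_file)

# Restore: start from an empty graph and rebuild the variables in the
# opposite order; the explicit names still match the checkpoint keys.
tf.compat.v1.reset_default_graph()
bias = tf.compat.v1.Variable(tf.random.truncated_normal([3]), name='bias')
weights = tf.compat.v1.Variable(tf.random.truncated_normal([2, 3]), name='weights')
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
    # No initializer: restore() assigns every variable from the checkpoint.
    saver.restore(sess, save_file)
    restored_w, restored_b = sess.run([weights, bias])

print(np.array_equal(saved_w, restored_w), np.array_equal(saved_b, restored_b))
```

Naming variables explicitly is a design choice that makes checkpoints robust to refactoring the graph-building code, since reordering the variable definitions no longer changes which checkpoint entry each one receives.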