1. 程式人生 > >tensorflow基礎(1)變數的建立、初始化、儲存與載入

tensorflow基礎(1)變數的建立、初始化、儲存與載入

廢話就不多說了,直接上乾貨。
1.變數的建立
tensoflow建立變數使用tf.Variable();需要指明變數的形狀

b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))

如這裡的w,b就是所要建立的變數。

2.初始化
變數的初始化,需要在變數操作執行前執行。

# 初始化變數
init = tf.global_variables_initializer()
.....
sess.run(init)#執行初始化操作

tf.global_variables_initializer()函式初始化了所有的變數。

3.變數的儲存與載入
這裡以一個例項為具體的模板進行講解。

#coding=utf-8
import tensorflow as tf
import numpy as np
import os
#判斷模型儲存路徑是否存在,不存在就建立
#if not os.path.exists('tem/'):
#    os.mkdir('tem/')
# 使用 NumPy 生成假資料(phony data), 總共 100 個點.
x_data = np.float32(np.random.rand(2, 100)) # 隨機輸入
y_data = np.dot([0.100, 0.200], x_data) + 0.300
# 構造一個線性模型 # b = tf.Variable(tf.zeros([1])) W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0)) y = tf.matmul(W, x_data) + b # 最小化方差 loss = tf.reduce_mean(tf.square(y - y_data)) optimizer = tf.train.GradientDescentOptimizer(0.5) train_op = optimizer.minimize(loss) # 初始化變數 #merged_summary_op = tf.summary.merge_all()
init = tf.global_variables_initializer() saver = tf.train.Saver() # 啟動圖 (graph) with tf.Session() as sess: #summary_writer = #tf.summary.FileWriter('tem/mnist_logs', sess.graph) #if p #sess.run(init) # 擬合平面 #path =os.path.join("", "tem/model.ckpt") ckpt = tf.train.get_checkpoint_state('tem') if ckpt != None: print(ckpt.model_checkpoint_path) saver.restore(sess, ckpt.model_checkpoint_path) else: sess.run(init) #saver.restore(sess,path) path = os.path.join("", "tem/model.ckpt") for step in range(0, 201): sess.run(train_op) #summary_str = #sess.run(merged_summary_op,feed_dict=#{x_data:x_data,y_data:y_data}) #summary_writer.add_summary(summary_str, step) #summary_writer.flush() if step % 20 == 0: print (step, sess.run(W), sess.run(b)) saver.save(sess, path,global_step=step)

儲存和恢復模型的方法是使用tf.train.Saver物件,預設儲存所有變數,但可以手動傳入要儲存的變數。
saver.save(session,path,global_step)是儲存模型,傳入的是sess,儲存的路徑,以及global_step=step,且必須先建立tem資料夾。

恢復模型的方法類似。saver.restore();
統一的框架,用於解決要麼存在模型,要麼沒有模型進行執行初始化。

ckpt = tf.train.get_checkpoint_state('tem')
    if ckpt != None:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(init)

恩,純屬個人見解,寫的不好,請給予批評指正。