1. 程式人生 > >tensorflow學習之路----保存和提取數據

tensorflow學習之路----保存和提取數據

work flow pri works var save lob spa pan

#保存數據註意他只能保存變量,不能保存神經網絡的框架。
#保存數據的作用:保存權重有利於下一次的訓練,或者可以用這個數據進行識別
#np.arange():arange函數用於創建等差數組,使用頻率非常高
import tensorflow as tf
#註意:在保存變量的時候,一定要寫出他的類型即dtype
Weights = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="Weights")
baises = tf.Variable([[1,32,3]],dtype=tf.float32,name="baises")

saver = tf.train.Saver()#保存數據和提取數據都需要先創建一個對象,才能使用

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  save_path = saver.save(sess,"E:\\workspace\\tensor\\tensor_data.ckpt")
  ‘‘‘
  保存數據使用的函數是:saver.save(sess,"保存路徑"),他可以返回保存路徑
  ‘‘‘
  print("save_path:",save_path)

#提取數據:這個寫在其他的文件的

‘‘‘
提取的數據要和保存的數據的shape和type一致


在提取數據的時候,已經不再需要初始化變量了,因為提取數據的時候,他會幫我們自動化初始化了
‘‘‘
import tensorflow as tf
import numpy as np
‘‘‘
提取數據的時候,變量的類型一定要相同
‘‘‘

W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="Weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="baises")

saver = tf.train.Saver()#保存數據和提取數據都需要先創建一個對象,才能使用

with tf.Session() as sess:
  saver.restore(sess,"E:\\workspace\\tensor\\tensor_data.ckpt")
  #提取到的數據會根據數據的類型自動分配到類型相同的數據變量中
  print("weights:",sess.run(W))
  print("baises:",sess.run(b))

tensorflow學習之路----保存和提取數據