1. 程式人生 > >TensorFlow的檔案儲存與讀取——variables_to_restore函式

TensorFlow的檔案儲存與讀取——variables_to_restore函式

轉,原創詳見: http://blog.csdn.net/sinat_29957455/article/details/78508793

variables_to_restore函式,是TensorFlow為滑動平均值提供。之前,也介紹過通過使用滑動平均值可以讓神經網路模型更加的健壯。我們也知道,其實在TensorFlow中,變數的滑動平均值都是由影子變數所維護的,如果你想要獲取變數的滑動平均值需要獲取的是影子變數而不是變數本身。

1、滑動平均值模型檔案的儲存

  1. import tensorflow as tf  
  2. if __name__ == "__main__":  
  3.     v = tf.Variable(0.,name=
    "v")  
  4.     #設定滑動平均模型的係數
  5.     ema = tf.train.ExponentialMovingAverage(0.99)  
  6.     #設定變數v使用滑動平均模型,tf.all_variables()設定所有變數
  7.     op = ema.apply([v])  
  8.     #獲取變數v的名字
  9.     print(v.name)  
  10.     #v:0
  11.     #建立一個儲存模型的物件
  12.     save = tf.train.Saver()  
  13.     sess = tf.Session()  
  14.     #初始化所有變數
  15.     init = tf.initialize_all_variables()  
  16.     sess.run(init)  
  17.     #給變數v重新賦值
  18.     sess.run(tf.assign(v,10))  
  19.     #應用平均滑動設定
  20.     sess.run(op)  
  21.     #儲存模型檔案
  22.     save.save(sess,"./model.ckpt")  
  23.     #輸出變數v之前的值和使用滑動平均模型之後的值
  24.     print(sess.run([v,ema.average(v)]))  
  25.     #[10.0, 0.099999905]
上面的程式碼,是如何來儲存一個滑動平均值的模型檔案,之前有介紹過滑動平均值和模型檔案的儲存,所以這裡就不再重複了。

2、滑動平均值模型檔案的讀取

  1. v = tf.Variable(
    1.,name="v")  
  2. #定義模型物件
  3. saver = tf.train.Saver({"v/ExponentialMovingAverage":v})  
  4. sess = tf.Session()  
  5. saver.restore(sess,"./model.ckpt")  
  6. print(sess.run(v))  
  7. #0.0999999
對於模型檔案的讀取,在上一篇部落格中有介紹過,這裡特別需要注意的一個地方就是,在使用tf.train.Saver函式中,所傳遞的模型引數是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是後面的引數,那麼你得到的結果將是10而不是0.09,那是因為後者獲取的是變數本身而不是影子變數。是不是感覺使用這種方式來讀取模型檔案的時候,還需要輸入一大串的變數名稱。

3、variables_to_restore函式的使用

  1. v = tf.Variable(1.,name="v")  
  2. #滑動模型的引數的大小並不會影響v的值
  3. ema = tf.train.ExponentialMovingAverage(0.99)  
  4. print(ema.variables_to_restore())  
  5. #{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
  6. sess = tf.Session()  
  7. saver = tf.train.Saver(ema.variables_to_restore())  
  8. saver.restore(sess,"./model.ckpt")  
  9. print(sess.run(v))  
  10. #0.0999999
通過使用variables_to_restore函式,可以使在載入模型的時候將影子變數直接對映到變數的本身,所以我們在獲取變數的滑動平均值的時候只需要獲取到變數的本身值而不需要去獲取影子變數。


相關推薦

TensorFlow檔案儲存讀取——variables_to_restore函式

轉,原創詳見: http://blog.csdn.net/sinat_29957455/article/details/78508793variables_to_restore函式,是TensorFlow為滑動平均值提供。之前,也介紹過通過使用滑動平均值可以讓神經網路模型更加的健壯。我們也知道,其實在Tens

csv檔案儲存讀取

csv是指逗號分隔值檔案格式,是以純文字形式來儲存表格資料。1.儲存為csv格式 把要儲存的表格中資料,按每一行存成一條記錄,然後將記錄新增到一個list裡面,寫入檔案的時候將每個記錄轉換成一串逗號分

json檔案儲存讀取

一、將資料儲存為.json檔案 1 model={} #資料 2 with open("./hmm.json",'w',encoding='utf-8') as json_file: 3 json.dump(model,json_file,ensure_a

TensorFlow之TFRecords檔案儲存讀取講解及程式碼實現

先聊一下tfrecord, 這是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,TFRecords是TensorFlow中的設計的一種內建的檔案格式,優點有如下幾種:統一不同輸入檔案的框架它是更好的利用記憶體,更方便複製和移動(TFRecord壓縮的二進位制檔

TensorFlow模型檔案儲存讀取

一、模型檔案的儲存 在訓練一個TensorFlow模型之後,我們可以將訓練好的模型儲存成檔案,這樣可以方便下一次對新的資料進行預測的時候直接載入訓練好的模型即可獲得結果,下面通過TensorFlow提供的tf.train.Saver函式,將一個模型儲存成檔案,一般習慣性的將TensorFlow的模

tensorflow 訓練模型的儲存 讀取儲存的模型進行測試

在實際中,通常需要將經過大量訓練的較好模型引數儲存起來,在實際應用以訓練好的模型進行預測。 TensorFlow中提供了模型儲存的模組 tensorflow.train.Saver() 1. 匯入tensorflow模組        

二進位制檔案操作-儲存讀取

目標:實現資料儲存為二進位制檔案,然後通過二進位制檔案解析資料。 目標分為三個階段:1、將資料結構轉化為二進位制(至於資料是怎樣讀取進來的這個就不說了,因為方式比較多,但是讀取進來一定都會以特定的資料結構形式來儲存)。2、載入二進位制文字。3、載入為對應的資料結構。 階段

[Deep Learning] TensorFlow模型、引數的儲存讀取

大部分情況,我們會把訓練的網路儲存下來,用於後面的使用。或者,在當前網路下對引數進行一定程度的微調。 儲存變數 import tensorflow as tf # The file path to save the data save_file = './mode

Tensorflow+MNIST+CNN+模型儲存讀取

# coding: utf-8 import tensorflow as tf import numpy as np from utils import * from tensorflow.examples.tutorials.mnist import inpu

sklearn、TensorFlow、keras模型儲存讀取

一、sklearn模型儲存與讀取 1、儲存 from sklearn.externals import joblib from sklearn import svm X = [[0, 0], [1, 1]] y = [0, 1] clf = svm.SVC(

tensorflow中模型資料的儲存讀取

使用tensorflow過程中,訓練結束後我們需要用到模型檔案。有時候,我們可能也需要用到別人訓練好的模型,並在這個基礎上再次訓練。這時候我們需要掌握如何操作這些模型資料。看完本文,相信你一定會有收穫! 1 Tensorflow模型檔案 我們在checkpoint_dir

Android採用SharedPreferences方式進行檔案儲存讀取

工程目錄: 做軟體開發應該都知道,很多軟體會有配置檔案,裡面存放這程式運行當中的各個屬性值,由於其配置資訊並不多,如果採用資料庫來存放並不划算,因為資料庫連線跟操作等耗時大大影響了程式的效率,因此我們使用鍵值這種一一對應的關係來存放這些配置資訊。SharedPrefe

TensorFlow 存儲讀取

.get 最佳實踐 variable .py 通過 nod spa clas utf 之前通過CNN進行的MNIST訓練識別成功率已經很高了,不過每次運行都需要消耗很多的時間。在實際使用的時候,每次都要選經過訓練後在進行識別那就太不方便了。 所以我們學習一下如何將訓練習得的

localStorage 儲存讀取

element ora onclick win nbsp n) cti item fun window.onload=function(){ /** *其實也沒那麽難, *localStorage.setItem(json) 向本地儲存值

預處理後資料的儲存讀取

在機器學習中,一般都需要先對資料進行資料預處理工作。模型一般需要反覆的調參,因此可能需要多次使用預處理之後的資料,但是反覆進行資料的預處理工作是多餘的,我們可以將其儲存下來。 #用pickle模組將處理好的資料儲存成pickle格式,方便以後呼叫,即建立一個checkpoint # 儲存資料方便呼叫

Tensorflow如何儲存讀取model (即利用訓練好的模型測試新資料的準確度)

目標: cnn2d.py cnn2d_test.py 訓練網路,並儲存網路模型 讀取網路,用測試集測試準確度 直接貼程式碼:(只貼了相關部分,瀏覽完整程式碼請到GitHub) 1. cnn2

C#操作mysql的建表、基本儲存讀取

由於專案中要使用到資料庫來儲存一些基本資料,所以在此做了一些簡單的總結,也是最基本的mysql資料操作。 一、資料庫的儲存 專案中儲存資料時,需要根據操作者在介面上的操作來決定是否儲存(操作介面上有建立表格和儲存資料的按鈕),所以在儲存資料前,如果點選“建立表格”的話,會在mysql資料庫中

Pytorch模型的儲存讀取方法

方法一(推薦) 只儲存和載入模型的引數 # 儲存模型引數 def save_model(the_model, PATH): torch.save(the_model.state_dict(), PATH) # 載入模型引數 def load_model(PATH):

Python程式設計:JSON格式資料的儲存讀取

JSON格式資料的儲存與讀取 簡介 模組 json 可以將 Python 資料結構轉儲到檔案中,並在程式再次執行時載入該檔案的資料 使用json.dump()和json.load() #儲存:json.dump() #函式json.dump()接受兩個實參:要儲

機器學習:Python模型的儲存讀取

在學習機器學習的過程中,很多人都應該有這樣的疑問:模型訓練好了,以後要用怎麼辦呢?肯定不能再跑一邊資料,重新訓練模型以供使用,因為這樣太費時間。最好的辦法當然是,訓練和預測分開。訓練好模型後,將模型儲存好,當需要預測時,直接讀取模型檔案來呼叫,進行預測。 無論是sklearn還是Tensor