1. 程式人生 > >TensorFlow利用saver儲存和提取引數

TensorFlow利用saver儲存和提取引數

在訓練迴圈中,定期呼叫 saver.save() 方法,向資料夾中寫入包含了當前模型中所有可訓練變數的 checkpoint 檔案。
saver.save(sess, FLAGS.train_dir, global_step=step)

global_step是訓練的第幾步

儲存引數:

import tensorflow as tf

W = tf.Variable([[1, 2, 3]], dtype=tf.float32)
b = tf.Variable([[1]], dtype=tf.float32)

saver = tf.train.Saver()

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# 必須要指定資料夾,儲存到ckpt檔案
save_path = saver.save(sess, "winycg/1.ckpt")
print(save_path)

一次 saver.save() 後可以在資料夾中看到新增的四個檔案,實際上每呼叫一次儲存操作會建立後3個數據檔案並建立一個檢查點(checkpoint)檔案,簡單理解就是權重等引數被儲存到 .chkp.data 檔案中,以字典的形式;圖和元資料被儲存到 .chkp.meta 檔案中,可以被 tf.train.import_meta_graph 載入到當前預設的圖。

讀取引數:

import tensorflow as tf
import numpy as np

W = tf.Variable(np.arange(3).reshape(1, 3), dtype=tf.float32)
b = tf.Variable(np.arange(1).reshape(1, 1), dtype=tf.float32)

saver = tf.train.Saver()

sess = tf.InteractiveSession()
# 讀取引數時不需要global_variables_initializer()
save_path = saver.restore(sess, "parameter/1.ckpt")
print("weights:", sess.run(W))
print("bias:", sess.run(b))

weights: [[ 1.  2.  3.]]
bias: [[ 1.]]

相關推薦

TensorFlow利用saver儲存提取引數

在訓練迴圈中,定期呼叫 saver.save() 方法,向資料夾中寫入包含了當前模型中所有可訓練變數的 checkpoint 檔案。 saver.save(sess, FLAGS.train_dir, global_step=step) global_step是訓練的第幾步

tensorflow saver 儲存恢復指定 tensor

在實踐中經常會遇到這樣的情況: 1, 用簡單的模型預訓練引數 2, 把預訓練的引數匯入複雜的模型後訓練複雜的模型 這時就產生一個問題:            &nb

TensorFlow儲存提取模型

一、模型的儲存 將訓練好的模型引數儲存起來,以便以後進行驗證或測試,這是我們經常要做的事情。tf裡面提供模型儲存的是tf.train.Saver()模組。 1、模型儲存,先要建立一個Saver物件:如 saver=tf.train.Saver() __init__( var_

tensorflow訓練權重儲存呼叫——tf.saver()

1、訓練權重並儲存 #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Fri Jul 28 15:18:58 2017 @author: dell """ import tensor

TensorFlow模型檔案儲存讀取

一、模型檔案的儲存 在訓練一個TensorFlow模型之後,我們可以將訓練好的模型儲存成檔案,這樣可以方便下一次對新的資料進行預測的時候直接載入訓練好的模型即可獲得結果,下面通過TensorFlow提供的tf.train.Saver函式,將一個模型儲存成檔案,一般習慣性的將TensorFlow的模

theano學習之模型的儲存提取

今天學習如何儲存神經網路,以方便日後可以直接提取使用。 儲存的方式是我們可以先把神經網路的引數,比如說 weights 還有 bias 儲存起來,再重新定義神經網路的結構,使用模型的時候需要把引數 set 到結構中去。 儲存和提取的方法是利用 shared 變數的 get 功能,拿出變數值儲

簡單完整地講解tensorflow模型的儲存恢復

http://blog.csdn.net/liangyihuai/article/details/78515913 在本教程主要講到: 1. 什麼是Tensorflow模型? 2. 如何儲存Tensorflow模式? 3. 如何還原預測/遷移學習Tensorflow模型? 4. 如

tensorflow】模型儲存恢復

  tensorflow中,模型的儲存和恢復使用tf.train.Saver類,模型儲存使用該類的 save 方法。模型恢復使用restore 方法。 模型儲存   模型儲存使用tf.train.Saver.save()方法。以saver.save(se

TensorFlowSaver儲存

一、Saver的介紹 有時可能只需要儲存或者載入部分變數。 比如,可能有一個之前訓練好的5層神經網路模型,但現在想寫一個6層的神經網路,那麼可以將之前5層神經網路中的引數直接載入到新的模型,而僅僅

TensorFlow 利用Dataset讀取構建資料

TensorFlow資料讀取方式: 利用placeholder讀取記憶體資料 Dataset API同時支援從記憶體和硬碟的讀取,相比之前的兩種方法在語法上更加簡潔易懂 Dataset建立和讀取資料集 Google官方給出的Dataset

TensorFlow學習筆記(六) tensorflow 中的儲存訓練的引數以及載入引數檢測test集

如何儲存訓練好的引數以前面練習的一個小例子,來儲存訓練好的引數:import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow import

TensorFlow訓練的模型引數(WB)利用np.savez()儲存用cnpy庫C++讀取

一、背景 大家經常會在Python環境下用TensorFlow訓練自己的深度學習模型。為了日後能夠使用訓練好的模型,在Python環境下,TensorFlow提供了 tf.train.Saver類,用來儲存模型。這個類提供了一整套函式用於方便的儲存和恢復模型! 但實際情

tensorflow儲存模型、載入模型提取模型引數特徵圖

1.tf.train.latest_checkpoint('./model_data/')這一句最終返回的是一個字串,比如'./model_data/model-99991'這個方法本身還會做相應的檢查,比如checkpoint中最新的模型model_checkpoint_p

Tensorflow之模型引數Saver儲存讀取

一、Saver儲存 import tensorflow as tf import numpy as np #定義W和b W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight') b = tf.Variable([

tensorflow模型的儲存與恢復(tf.train.Saver()saver.restore()方法的運用)

注意: 在tensorflow中,一旦有涉及到檔案路徑的相關操作,請確保檔案路徑上所有目錄名為英文!!!!否則可能會出現奇怪的錯誤!! TensorFlow模型的儲存與恢復(使用tf.train.Saver()和saver.restore()): 首先我們需要建立一個用來儲存模型的物

Tensorflow模型引數Saver儲存讀取

一、Saver儲存 import tensorflow as tf import numpy as np #定義W和b W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight') b = tf

tensorflow學習之路----保存提取數據

work flow pri works var save lob spa pan #保存數據註意他只能保存變量,不能保存神經網絡的框架。#保存數據的作用:保存權重有利於下一次的訓練,或者可以用這個數據進行識別#np.arange():arange函數用於創建等差數組,使用頻

基於pytorch的 儲存載入模型引數

當我們花費大量的精力訓練完網路,下次預測資料時不想再(有時也不必再)訓練一次時,這時候torch.save(),torch.load()就要登場了。 儲存和載入模型引數有兩種方式: 方式一:   torch.save(net.state_dict(),path): 功能

tensorflow將圖片儲存為tfrecordtfrecord的讀取

tensorflow官方提供了3種方法來讀取資料: 預載入資料(preloaded data):在TensorFlow圖中定義常量或變數來儲存所有的資料,適用於資料量不太大的情況。 填充資料(feeding):通過Python產生資料,然後再把資料填充到後端。 從檔案讀取資

【學習筆記】Hands-on ML with sklearn&tensorflow [TF] [1]模型的訓練、儲存載入

本篇內容:一個簡單的預測模型的建立、訓練、儲存和載入。 匯入必要模組: import numpy as np import pandas as pd import tensorflow as tf import ssl #解決資料來源網站簽名認證失敗的問題 from sklearn.data