1. 程式人生 > >tensorflow 檢查點和模型,儲存與恢復使用,官方教程(一)

tensorflow 檢查點和模型,儲存與恢復使用,官方教程(一)

  • 檢查點:這種格式依賴於建立模型的程式碼。
  • SavedModel:這種格式與建立模型的程式碼無關。

示例程式碼

本文件依賴於 TensorFlow 使用入門中詳細介紹的同一個鳶尾花分類示例。要下載和訪問該示例,請執行下列兩個命令:

git clone https://github.com/tensorflow/models/cd models/samples/core/get_started

本文件中的大部分程式碼段都是在 premade_estimator.py 的基礎上加以細微修改。

儲存經過部分訓練的模型

Estimator 自動將以下內容寫入磁碟:

  • 檢查點:訓練期間所建立的模型版本。
  • 事件檔案:其中包含 TensorBoard
     用於建立視覺化圖表的資訊。

要指定 Estimator 在其中儲存其資訊的頂級目錄,請為任何 Estimator 的建構函式的可選 model_dir 引數分配一個值。例如,以下程式碼將 model_dir 引數設定為 models/iris 目錄:

classifier = tf.estimator.DNNClassifier(
    feature_columns
=my_feature_columns,
    hidden_units
=[10,10],
    n_classes
=3,
    model_dir
='models/iris')

假設您呼叫 Estimator 的 train 方法。例如:

classifier
.train(        input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),                steps=200)

如下圖所示,第一次呼叫 train 會將檢查點和其他檔案新增到 model_dir 目錄中:

第一次呼叫 train()。

要在基於 UNIX 的系統上檢視建立的 model_dir 目錄中的物件,只需按如下方式呼叫 ls 即可:

$ ls -1 models/iris
checkpoint
events
.out.tfevents.timestamp.hostname
graph
.pbtxt
model
.ckpt
-1.data-00000-of-00001
model
.ckpt-1.index
model
.ckpt-1.meta
model
.ckpt-200.data-00000-of-00001
model
.ckpt-200.index
model
.ckpt-200.meta

上述 ls 命令展示了 Estimator 在第 1 步(訓練開始)和第 200 步(訓練結束)建立了檢查點。

預設檢查點目錄

如果您未在 Estimator 的建構函式中指定 model_dir,則 Estimator 會將檢查點檔案寫入由 Python 的 tempfile.mkdtemp 函式選擇的臨時目錄中。例如,以下 Estimator 建構函式不會指定 model_dir 引數:

classifier = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[10,10],    n_classes=3)print(classifier.model_dir)

tempfile.mkdtemp 函式會根據您的作業系統選擇安全的臨時目錄。例如,Mac 作業系統上的典型臨時目錄可能如下所示:

/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa

檢查點頻率

預設情況下,Estimator 按照以下時間安排將檢查點儲存到 model_dir 中:

  • 每 10 分鐘(600 秒)寫入一個檢查點。
  • 在 train 方法開始(第一次迭代)和完成(最後一次迭代)時寫入一個檢查點。
  • 只在目錄中保留 5 個最近寫入的檢查點。

您可以通過執行下列步驟來更改預設時間安排:

  1. 建立一個 RunConfig 物件來定義所需的時間安排。
  2. 在例項化 Estimator 時,將該 RunConfig 物件傳遞給 Estimator 的 config 引數。

例如,以下程式碼將檢查點時間安排更改為每 20 分鐘儲存一次並保留 10 個最近寫入的檢查點:

my_checkpointing_config = tf.estimator.RunConfig(    save_checkpoints_secs =20*60,# Save checkpoints every 20 minutes.    keep_checkpoint_max =10,# Retain the 10 most recent checkpoints.)classifier = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[10,10],    n_classes=3,    model_dir='models/iris',    config=my_checkpointing_config)

恢復模型

第一次呼叫 Estimator 的 train 方法時,TensorFlow 會將一個檢查點儲存到 model_dir 中。隨後每次呼叫 Estimator 的 trainevaluate 或 predict 方法時,都會發生下列情況:

  1. Estimator 通過執行 model_fn() 構建模型。(要詳細瞭解 model_fn(),請參閱建立自定義 Estimator。)
  2. Estimator 根據最近寫入的檢查點中儲存的資料來初始化新模型的權重。

換言之,如下圖所示,一旦存在檢查點,TensorFlow 就會在您每次呼叫 train()evaluate() 或 predict() 時重建模型。

對 train()、evaluate() 或 predict() 的後續呼叫

避免不當恢復

通過檢查點恢復模型的狀態這一操作僅在模型和檢查點相容時可行。例如,假設您訓練了一個 DNNClassifier Estimator,它包含 2 個隱藏層且每層都有 10 個節點:

classifier = tf.estimator.DNNClassifier(
    feature_columns
=feature_columns,
    hidden_units
=[10,10],
    n_classes
=3,
    model_dir
='models/iris')

classifier
.train(
    input_fn
=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps
=200)

在訓練之後(因此已在 models/iris 中建立檢查點),假設您將每個隱藏層中的神經元數量從 10 更改為 20,然後嘗試重新訓練模型:

classifier2 = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[20,20],# Change the number of neurons in the model.    n_classes=3,    model_dir='models/iris')classifier.train(    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),        steps=200)

由於檢查點中的狀態與 classifier2 中描述的模型不相容,因此重新訓練失敗並出現以下錯誤:

...InvalidArgumentError(see above for traceback): tensor_name =
dnn
/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does
not match the shape stored in checkpoint:[20]

要執行實驗(訓練並比較略微不同的模型版本),請儲存一份建立每個 model_dir 所需的程式碼的副本,同時可能需要為每個版本建立一個單獨的 git 分支。這種區分將有助於保證檢查點的可恢復性。