tensorflow 檢查點和模型,儲存與恢復使用,官方教程(一)
- 檢查點:這種格式依賴於建立模型的程式碼。
- SavedModel:這種格式與建立模型的程式碼無關。
示例程式碼
本文件依賴於 TensorFlow 使用入門中詳細介紹的同一個鳶尾花分類示例。要下載和訪問該示例,請執行下列兩個命令:
git clone https://github.com/tensorflow/models/cd models/samples/core/get_started
本文件中的大部分程式碼段都是在 premade_estimator.py
的基礎上加以細微修改。
儲存經過部分訓練的模型
Estimator 自動將以下內容寫入磁碟:
- 檢查點:訓練期間所建立的模型版本。
- 事件檔案:其中包含 TensorBoard
要指定 Estimator 在其中儲存其資訊的頂級目錄,請為任何 Estimator 的建構函式的可選 model_dir
引數分配一個值。例如,以下程式碼將 model_dir
引數設定為 models/iris
目錄:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10,10],
n_classes=3,
model_dir='models/iris')
假設您呼叫 Estimator 的 train
方法。例如:
classifier .train( input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100), steps=200)
如下圖所示,第一次呼叫 train
會將檢查點和其他檔案新增到 model_dir
目錄中:
要在基於 UNIX 的系統上檢視建立的 model_dir
目錄中的物件,只需按如下方式呼叫 ls
即可:
$ ls -1 models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt -1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
上述 ls
命令展示了 Estimator 在第 1 步(訓練開始)和第 200 步(訓練結束)建立了檢查點。
預設檢查點目錄
如果您未在 Estimator 的建構函式中指定 model_dir
,則 Estimator 會將檢查點檔案寫入由 Python 的 tempfile.mkdtemp 函式選擇的臨時目錄中。例如,以下 Estimator 建構函式不會指定 model_dir
引數:
classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, hidden_units=[10,10], n_classes=3)print(classifier.model_dir)
tempfile.mkdtemp
函式會根據您的作業系統選擇安全的臨時目錄。例如,Mac 作業系統上的典型臨時目錄可能如下所示:
/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa
檢查點頻率
預設情況下,Estimator 按照以下時間安排將檢查點儲存到 model_dir
中:
- 每 10 分鐘(600 秒)寫入一個檢查點。
- 在
train
方法開始(第一次迭代)和完成(最後一次迭代)時寫入一個檢查點。 - 只在目錄中保留 5 個最近寫入的檢查點。
您可以通過執行下列步驟來更改預設時間安排:
- 建立一個
RunConfig
物件來定義所需的時間安排。 - 在例項化 Estimator 時,將該
RunConfig
物件傳遞給 Estimator 的config
引數。
例如,以下程式碼將檢查點時間安排更改為每 20 分鐘儲存一次並保留 10 個最近寫入的檢查點:
my_checkpointing_config = tf.estimator.RunConfig( save_checkpoints_secs =20*60,# Save checkpoints every 20 minutes. keep_checkpoint_max =10,# Retain the 10 most recent checkpoints.)classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, hidden_units=[10,10], n_classes=3, model_dir='models/iris', config=my_checkpointing_config)
恢復模型
第一次呼叫 Estimator 的 train
方法時,TensorFlow 會將一個檢查點儲存到 model_dir
中。隨後每次呼叫 Estimator 的 train
、evaluate
或 predict
方法時,都會發生下列情況:
- Estimator 通過執行
model_fn()
構建模型圖。(要詳細瞭解model_fn()
,請參閱建立自定義 Estimator。) - Estimator 根據最近寫入的檢查點中儲存的資料來初始化新模型的權重。
換言之,如下圖所示,一旦存在檢查點,TensorFlow 就會在您每次呼叫 train()
、evaluate()
或 predict()
時重建模型。
避免不當恢復
通過檢查點恢復模型的狀態這一操作僅在模型和檢查點相容時可行。例如,假設您訓練了一個 DNNClassifier
Estimator,它包含 2 個隱藏層且每層都有 10 個節點:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10,10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
在訓練之後(因此已在 models/iris
中建立檢查點),假設您將每個隱藏層中的神經元數量從 10 更改為 20,然後嘗試重新訓練模型:
classifier2 = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, hidden_units=[20,20],# Change the number of neurons in the model. n_classes=3, model_dir='models/iris')classifier.train( input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100), steps=200)
由於檢查點中的狀態與 classifier2
中描述的模型不相容,因此重新訓練失敗並出現以下錯誤:
...InvalidArgumentError(see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint:[20]
要執行實驗(訓練並比較略微不同的模型版本),請儲存一份建立每個 model_dir
所需的程式碼的副本,同時可能需要為每個版本建立一個單獨的 git 分支。這種區分將有助於保證檢查點的可恢復性。