【tf.keras】使用手冊
目錄
- 0. 簡介
- 1. 安裝
- 1.1 安裝 CUDA 和 cuDNN
- 2. 資料集
- 2.1 使用 tensorflow_datasets 匯入公共資料集
- 2.2 資料集過大導致記憶體溢位
- 2.3 載入 cifar10 資料時報錯
- 3. 評價指標
- 3.1 實現 F1 socre、precsion、recall
- 4. 優化器
- 4.1 AdamW 優化器示例程式
- 4.2 tf.keras 1.x 在使用 learning rate decay 時不要使用 tf.train 內的優化器
- 5. 模型
- 5.1 模型復現
- 5.2 載入 AlexNet 預訓練模型
- 5.3 迴圈訓練模型導致 OOM
- 6. TensorFlow API 變化
- 5.1 TF 1.x 到 TF 2.0 API 變化,隨機種子、動態分配視訊記憶體
- 5.2 TF 2.1 API 變化
0. 簡介
記錄使用 tf.keras 時遇到的各種問題。
tf.keras 是 keras 的未來,keras 作為 TensorFlow 的高階 API,大大簡化 TensorFlow 程式碼的編寫過程。
Keras(單獨的)、TensorFlow 1.x 和 TensorFlow 2.0 的 keras API 變化不大,手冊可以通用。
tf.keras 快速入門:初學者的 TensorFlow 2.0 教程 | TensorFlow Core
1. 安裝
1.1 安裝 CUDA 和 cuDNN
【tf.keras】Linux 非 root 使用者安裝 CUDA 和 cuDNN
2. 資料集
2.1 使用 tensorflow_datasets 匯入公共資料集
【tf.keras】tensorflow datasets,tfds
2.2 資料集過大導致記憶體溢位
【tf.keras】在 cifar 上訓練 AlexNet,資料集過大導致 OOM
2.3 載入 cifar10 資料時報錯
【tf.keras】ssl.SSLError: [SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption failed or bad record mac (_ssl.c:1977)
3. 評價指標
3.1 實現 F1 socre、precsion、recall
在整個資料集而不是單個 batch 上實現 F1 socre、precsion、recall 等評價指標:
【tf.keras】實現 F1 score、precision、recall 等 metric
4. 優化器
4.1 AdamW 優化器示例程式
【tf.keras】AdamW: Adam with Weight decay
4.2 tf.keras 1.x 在使用 learning rate decay 時不要使用 tf.train 內的優化器
【tf.keras】tf.keras使用tensorflow中定義的optimizer
5. 模型
5.1 模型復現
【tf.keras】tf.keras模型復現
(注意:在CPU上訓練才能完全復現模型)
5.2 載入 AlexNet 預訓練模型
【tf.keras】tf.keras載入AlexNet預訓練模型
5.3 迴圈訓練模型導致 OOM
【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
6. TensorFlow API 變化
5.1 TF 1.x 到 TF 2.0 API 變化,隨機種子、動態分配視訊記憶體
【tf.keras】TensorFlow 1.x 到 2.0 的 API 變化
5.2 TF 2.1 API 變化
TensorFlow 2.1 將 fit_generator(), evaluate_generator(), predict_generator() 等函式分別合併到 fit(),evaluate(),predict()