1. 程式人生 > >【tf.keras】使用手冊

【tf.keras】使用手冊

目錄

  • 0. 簡介
  • 1. 安裝
    • 1.1 安裝 CUDA 和 cuDNN
  • 2. 資料集
    • 2.1 使用 tensorflow_datasets 匯入公共資料集
    • 2.2 資料集過大導致記憶體溢位
    • 2.3 載入 cifar10 資料時報錯
  • 3. 評價指標
    • 3.1 實現 F1 socre、precsion、recall
  • 4. 優化器
    • 4.1 AdamW 優化器示例程式
    • 4.2 tf.keras 1.x 在使用 learning rate decay 時不要使用 tf.train 內的優化器
  • 5. 模型
    • 5.1 模型復現
    • 5.2 載入 AlexNet 預訓練模型
    • 5.3 迴圈訓練模型導致 OOM
  • 6. TensorFlow API 變化
    • 5.1 TF 1.x 到 TF 2.0 API 變化,隨機種子、動態分配視訊記憶體
    • 5.2 TF 2.1 API 變化

0. 簡介

記錄使用 tf.keras 時遇到的各種問題。

tf.keras 是 keras 的未來,keras 作為 TensorFlow 的高階 API,大大簡化 TensorFlow 程式碼的編寫過程。

Keras(單獨的)、TensorFlow 1.x 和 TensorFlow 2.0 的 keras API 變化不大,手冊可以通用。

tf.keras 快速入門:初學者的 TensorFlow 2.0 教程 | TensorFlow Core

1. 安裝

1.1 安裝 CUDA 和 cuDNN

【tf.keras】Linux 非 root 使用者安裝 CUDA 和 cuDNN

2. 資料集

2.1 使用 tensorflow_datasets 匯入公共資料集

【tf.keras】tensorflow datasets,tfds

2.2 資料集過大導致記憶體溢位

【tf.keras】在 cifar 上訓練 AlexNet,資料集過大導致 OOM

2.3 載入 cifar10 資料時報錯

【tf.keras】ssl.SSLError: [SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption failed or bad record mac (_ssl.c:1977)

3. 評價指標

3.1 實現 F1 socre、precsion、recall

在整個資料集而不是單個 batch 上實現 F1 socre、precsion、recall 等評價指標:
【tf.keras】實現 F1 score、precision、recall 等 metric

4. 優化器

4.1 AdamW 優化器示例程式

【tf.keras】AdamW: Adam with Weight decay

4.2 tf.keras 1.x 在使用 learning rate decay 時不要使用 tf.train 內的優化器

【tf.keras】tf.keras使用tensorflow中定義的optimizer

5. 模型

5.1 模型復現

【tf.keras】tf.keras模型復現
(注意:在CPU上訓練才能完全復現模型)

5.2 載入 AlexNet 預訓練模型

【tf.keras】tf.keras載入AlexNet預訓練模型

5.3 迴圈訓練模型導致 OOM

【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc

6. TensorFlow API 變化

5.1 TF 1.x 到 TF 2.0 API 變化,隨機種子、動態分配視訊記憶體

【tf.keras】TensorFlow 1.x 到 2.0 的 API 變化

5.2 TF 2.1 API 變化

TensorFlow 2.1 將 fit_generator(), evaluate_generator(), predict_generator() 等函式分別合併到 fit(),evaluate(),predict()