1. 程式人生 > >21個專案玩轉深度學習:基於TensorFlow的實踐詳解03—打造自己的影象識別模型

21個專案玩轉深度學習:基於TensorFlow的實踐詳解03—打造自己的影象識別模型

書籍原始碼:https://github.com/hzy46/Deep-Learning-21-Examples

CNN的發展已經很多了,ImageNet引發的一系列方法,LeNet,GoogLeNet,VGGNet,ResNet每個方法都有很多版本的衍生,tensorflow中帶有封裝好各方法和網路的函式,只要餵食自己的訓練集就可以完成自己的模型,感覺超方便!!!激動!!!因為雖然原理流程瞭解了,但是要寫出來真的。。。。好難,臣妾做不到啊~~~~~~~~

START~~~~

1.資料準備

首先了解下微調的概念: 以VGG為例

他的結構是卷積+全連線,卷積層分為5個部分共13層,conv1~conv5。還有三層全連線,即fc6,fc7,fc8。總共16層,因此被稱為VGG16。

a.如果要將VGG16的結構用於一個新的資料集,首先要去掉fc8,因為fc8原本的輸出是1000類的概率。需要改為符合自身訓練集的輸出類別數。

b.訓練的時候,網路的引數的初始值並不是隨機化生成的,而是採用VGG16在ImageNet上已經訓練好的引數作為訓練的初始值。因為已經訓練過的VGG16中的引數已經包含了大量有用的卷積過濾器,這樣做不僅節約大量訓練時間,而且有助於分類器效能的提高。

載入VGG16的引數後,即可開始訓練。此時需要指定訓練層數的範圍。一般而言,可以選擇以下幾種:

  • 只訓練fc8:訓練範圍一定要包含fc8這一層。這樣的選擇一般效能都不會太好,但速度很快;因為他只訓練fc8,保持其他層的引數不動,相當於把VGG16當成一個“特徵提取器”,用fc7層提取的特徵做一個softmax的模型分類。
  • 訓練所有引數:耗時較慢,但能取得較高效能。
  • 訓練部分引數:通常是固定淺層引數不變,訓練深層引數。如固定conv1、conv2部分的引數不訓練,只訓練conv3、conv4、conv5、fc6、fc7、fc8的引數。

這種訓練方法就是對神經網路做微調。

1.1 切分train&test

書中提供了衛星影象資料集,有6個類別,分別是森林(wood),水域(water),岩石(rock),農田(wetland),冰川(glacier),城市區域(urban)

儲存結構為data_prepare/pic,再下層有兩個資料夾train和validation,各資料夾下有6個資料夾,放的是該類別下的圖片。

1.2 轉換成tfrecord格式

python data_convert.py -t pic/ \
  --train-shards 2 \
  --validation-shards 2 \
  --num-threads 2 \
  --dataset-name satellite

引數解釋:

-t pic/ :表示轉換pic資料夾下的資料,該資料夾必須與上面的檔案結構保持一致

--train-shards 2 :把訓練集分成兩塊,即最後的訓練資料就是兩個tfrecord格式的檔案。若資料集更大,可以分更多資料塊

--validation-shards 2 :把驗證集分成兩塊

--num-thread 2 :用兩個執行緒來產生資料。注意執行緒數必須要能整除train-shards和validation-shards,來保證每個執行緒處理的資料塊是相同的。

--dataset-name :給生成的資料集起個名字,即表示最後生成檔案的開頭是satellite_train和satellite_validation

執行上述命令後,就可以在 pic 資料夾中找到 5 個新生成的檔案 ,分別是:

  • 訓練資料 satellite_train_00000-of-00002.tfrecord、satellite_train_00001-of-00002.tfrecord,
  • 驗證資料 satellite_validation_00000-of-00002.tfrecord、satellite_validation_00001-of-00002.tfrecord。
  • label.txt 它表示圖片的內部標籤(數字)到真實類別(字串)之間的對映順序 。 如圖片在 tfrecord 中的標籤為 0 ,那麼就對應 label.txt 第一行的類別,在 tfrecord的標籤為1,就對應 label.txt 中第二行的類別,依此類推。

2.訓練模型

2.1 TensorFlow Slim

Google 公司公佈的一個影象分類工具包,它不僅定義了一些方便的介面,還提供了很多 ImageNet 資料集上常用的網路結構和預訓練模型

截至2017年7月,Slim 提供包括 VGG16VGG19InceptioV1 ~ V4、ResNet 50、ResNet 101、MobileNet 在內大多數常用模型的結構及預訓練模型,更多的模型還會被持續新增進來

原始碼地址: https://github.com/tensorflow/models/tree/master/research/slim

可以通過執行  git clone https://github.corn/tensorflow/models.git  來獲取

 

2.2 定義新的datasets檔案<修改slim原始碼>

 

 

2.3 準備訓練資料夾

 

 

2.4 開始訓練

 

 

3.驗證準確率

 

 

4.匯出模型並對單張圖片分類

 

THE END~~~~