1. 程式人生 > >MTCNN(一)概覽與執行

MTCNN(一)概覽與執行

一、相關知識

三個CNN串聯

P-Net:(12-Net)Fast proposal Network,找出備選框。

R-Net:(24-Net)Refinement Network,對備選框進行重新選擇   

O-Net:(48-Net)Output Network,輸出最終的box與最終的landmark position

二、訓練與執行 

2.1 背景知識

tfrecords

2.2 程式作用

gen_12net_data.py     訓練PNet的資料的取樣程式碼gen_hard_example.py     分別生成RNet和ONet的訓練資料gen_imglist_xxnet.py

    分別將三個網路的三個任務(分類,迴歸,特徵點檢測)的資料彙總到一個檔案中gen_xx_tfrecords.py     分別生成3個網路的tfrecord,在這裡需要注意:         PNet的訓練資料(pos,neg,part,landmark)是混在一起的,生成了一個tfrecord     RNet和ONet的各自需要生成4個tfrecord(pos,neg,part,landmark),因為要控制各部分的樣本比例(1:3:1:1)loader.py     迭代器,用於讀取圖片read_tfrecord_v2.py/tfrecord_utils.py     用於讀取tfrecord資料,並對其解析utils.py
    用於一些資料處理操作gen_landmark_tfrecords_aug_xx.py     用於生成特徵點的資料,在這裡並沒有生成tfreord,只是對進行資料增強(隨機映象、隨機旋轉)     此指令碼的輸入是trainImageList.txt,其中定義了檔案的路徑,人臉框的位置(x1,x2,y1,y2),特徵點的位置(x1,y1,,,,,x5,y5)BBox_utils.py/landmark_utils.py     用於特徵點處理

2.3 命令列

訓練流程:

資料集下載,格式轉換。

以下每一步都要轉換目錄。訓練為根目錄,生成樣本為/prepare_data目錄

生成12net的訓練資料,把12net的訓練資料生成tfrecords檔案

python gen_shuffle_data.py 12
python gen_tfdata_12net.py 

訓練12net(P-Net)

python src/mtcnn_pnet_test.py

根據訓練結果生成12net輸出的難樣本,生成24net隨機樣本併合並這兩個生成的樣本

python tf_gen_12net_hard_example.py
python gen_shuffle_data.py 24
python gen_tfdata_24net.py

訓練24net(R-Net)

python src/mtcnn_rnet_test.py

根據生成的網路生成24net輸出的難樣本,生成48net隨機樣本併合並這兩個生成的樣本

python tf_gen_24net_hard_example.py
python gen_shuffle_data.py 48
python gen_tfdata_48net.py

訓練48-Net(O-Net)

python src/mtcnn_onet_test.py

2.4 執行程式

python test_img.py (IMAGE_PATH) --model_dir ./save_model/all_in_one
python test_img.py images/0_1_f_36.jpg --model_dir ./save_model/new_saver