1. 程式人生 > >深度學習與計算機視覺(PB-04)-rank-N準確度

深度學習與計算機視覺(PB-04)-rank-N準確度

在我們深入討論高階深度學習主題(如遷移學習)之前,先來了解下rank-1、rank-5和rank-N準確度的概念。當你在閱讀深度學習相關文獻時,尤其是關於計算機視覺和影象分類,你很可能會看到關於rank-N 準確度。例如,幾乎所有在ImageNet資料集上驗證的機器學習方法的論文都給出了rank-1和rank-5準確度 (我們將在本章後面解釋為什麼需要使用rank-1和rank-5準確度).

rank-N準確度指標與傳統的評估指標有何不同呢?在本節中,我們將討論rank-N準確度內容以及如何實現它。最後將其應用於在Flower-17和CALTECH-101資料集上。

rank-N準確度

通過一個例子來解釋rank-N準確度概念。假設我們正在評估一個訓練在CIFAR-10資料集上的神經網路模型,CIFAR-10資料集包括10個類:飛機,汽車,鳥、貓、鹿、狗、青蛙、馬、船和卡車。給定一張輸入影象(如圖4.1左)

圖4.1 左:青蛙, 右:汽車

模型返回的結果是表4.1左的類標籤概率資訊。

表4.1 左:圖4.1左圖預測結果, 右:圖4.1右圖預測結果

我們先看看rank-1的計算,對於每一張圖片,取模型預測的類概率列表中最大的概率對應的標籤作為該圖片的預測結果。比如,我們使用圖4.1左對應真實標籤為青蛙的圖片進行預測,得到表4.1左結果,從中可以看到最大概率為97.3%對應的預測結果也為青蛙,說明預測結果是對的。因此,可以看到計算rank-1的整個過程為:

  • 步驟1:計算資料集中每個輸入影象的類標籤概率。
  • 步驟2:原始標籤與對應概率最大的標籤進行比較,若相同為true,反之false
  • 步驟3:統計步驟2為true的個數

上面我們計算的是rank-1準確度,即對應預測最高概率的標籤與真實標籤相同的個數佔總個數的百分比——標籤相同的個數 / 總資料個數。

現在,我們擴充套件到rank-5準確度,我們關注的不是top1的預測,而是top5的預測,那麼整個計算過程如下:

  • 步驟1: 計算資料集中每個輸入影象的類標籤概率。
  • 步驟2: 對預測的類標籤概率進行降序排序
  • 步驟3: 判斷真實的標籤是否落在預測的top5標籤裡面,若存在,則標記為true,反之false
  • 步驟4: 統計步驟3中為true的個數

rank-5準確度是rank-1準確度的擴充套件,我們對一張圖片的預測結果是來自模型輸出結果中top5對應的5個預測,而不是top1的1個預測。例如,我們對圖4.1右圖片進行預測,rank-5對應的預測結果為表4.1右結果。

很顯然圖4.1右是一輛汽車,然而,如果使用的是rank-1預測的話,結果為卡車,顯然是不對的。但是如果使用rank-5的話,發現汽車實際上是第2個預測結果,這時候對於rank-5預測而言是正確的。這種方法也可以很容易地推廣到計算rank-N準確度。 一般而言,我們只計算rank-1和rank-5準確度——計算rank-1的準確度可以理解,為什麼還需要計算rank-5準確度呢?

對於CIFAR-10資料集來說,由於本身類別個數不多,計算rank-5準確度有點不太合適。但對於大型的、具有挑戰性的資料集來說,特別是細粒度的分類。從Szegedy[17]等人的論文中的一個例子或許可以很好的解釋為什麼需要計算rank-1和rank-5準確度。比如圖4.2中,我們可以看到左邊是西伯利亞哈士奇,右邊是愛斯基摩犬。從人的肉眼來看是無法區分開的,但是這個在ImageNet 資料集中是有效的標籤。

圖4.2,左:西伯利亞哈士奇,右: 愛斯基摩犬

當處理的大型資料集各個類別之間存在許多具有相似特徵時,我們往往會增加一個rank-5準確度,也就是說我們不止關心rank-1準確度,也關心rank-5準確度。結合兩個準確度來以衡量神經網路的效能。理想情況下,隨著預測資料增加,希望rank-1準確度和rank-5準確度同比例增加。但是,在某些資料集上,情況往往並非總是如此。

因此,我們也根據rank-5準確度檢驗模型,以確保我們的網路在後面的迭代中仍然是“學習”的。在訓練快結束時,rank-1準確度可能會停滯不前,但是當我們的網路學習到更多的識別特徵(雖然沒有足夠的識別能力超過top1的預測)時,rank-5準確度會繼續提高。

實現rank-1和rank-5準確度

我們可以通過在專案中構建一個工具模組來計算rank-1和rank-5準確度。因此,在pyimagesearch
專案中增加一個子模組utils,並在子模組中增加一個ranked.py指令碼,整個目錄結構如下:

--- pyimagesearch
|    |--- __init__.py
|    |--- callbacks
|    |--- io
|    |--- nn
|    |--- preprocessing
|    |--- utils
|        |--- __init__.py
|        |--- captchahelper.py
|        |--- ranked.py

開啟ranked.py指令碼,寫入以下程式碼:

#encoding:utf-8
import numpy as np
def rank5_accuracy(preds,labels):
    #初始化
    rank1 = 0
    rank5 = 0

定義了rank5_accuracy函式,主要需要傳入兩個引數:

  • preds: 一個NxT的矩陣,其中N表示行數,T表示列數,每個值代表對應標籤下的概率
  • labels: 原始資料中的真實標籤

接下來計算rank-1和rank-5:

    # 遍歷資料集
    for (p,gt) in zip(preds,labels):
        # 通過降序對概率進行排序
        p = np.argsort(p)[::-1]
        # 檢查真實標籤是否落在top5中
        if gt in p[:5]:
            rank5 += 1
        # 檢驗真實標籤是否等於top1
        if gt == p[0]:
            rank1 += 1
            # 計算準確度
    rank1 /= float(len(labels))
    rank5 /= float(len(labels))
    return rank1,rank5

應用

第2節中,我們使用了預先訓練好的VGG16模型對三種資料集提取了特徵,並對特徵向量訓練了邏輯迴歸模型,以及對模型進行了評估,接下來,我們將使用rank-1和rank-5準確度進行型評估。

新建一個指令碼檔案,名為rank_accuracy.py,並寫入以下程式碼:

#encoding:utf-8
from pyimagesearch.utils.ranked import rank5_accuracy
import argparse
import pickle
import h5py

接下來,解析命令列引數:

# 解析命令列引數
ap = argparse.ArgumentParser()
ap.add_argument('-d','--db',required=True,help='path HDF5 databases')
ap.add_argument('-m','--model',required=True,help = 'path to pre-trained model')
args = vars(ap.parse_args())

主要有兩個引數:

  • –db: HDF5資料路徑
  • –model:之前訓練好的logistic regression模型路徑

由於我們使用的是前75%的資料進行訓練,因此,我們使用後25%資料進行預測和評估:

# 載入模型
print("[INFO] loading pre-trained model...")
model = pickle.loads(open(args['model'],'rb').read())

db = h5py.File(args['db'],'r')
i = int(db['labels'].shape[0] * 0.75)
# 預測
print ("[INFO] predicting....")
preds = model.predict_proba(db['features'][i:])
(rank1,rank5) = rank5_accuracy(preds,db['labels'][i:])
# 結果列印
print("[INFO] rank-1:{:.2f}%".format(rank1 * 100))
print("[INFO] rank-5:{:.2f}%".format(rank5 * 100))
db.close()

Flowers-17結果

下面我們使用Flowers-17資料進行實驗,執行下面命令:

$ python rank_accuracy.py --db youPath/data/flowers17/hdf5/features.hdf5 -model youPath/flowers17.cpickle

將得到如下結果:

[INFO] loading pre-trained model...
[INFO] predicting....
[INFO] rank-1:90.00%
[INFO] rank-5:99.71%

CALTECH-101結果

我們嘗試另外一個數據例子—CALTECH-101,執行下面程式碼:

$ python rank_accuracy.py --db youPath/data/caltech101/hdf5/features.hdf5 --model youPath/caltech101.cpickle

得到的結果如下;

[INFO] loading pre-trained model...
[INFO] predicting...
[INFO] rank-1: 95.58%
[INFO] rank-5: 99.45%

總結

在本節中,我們討論了rank-1和rank-5準確度概念。在大型的、具有挑戰性的資料集(如ImageNet)上,除了要關注rank-1準確度,還需要關注rank-5準確度,在這些資料集中,即使是人眼檢視也無法正確地給每一張影象貼上真實的標籤。在這種情況下,如果真實標籤存在於top5預測中,那麼可以認為我們的模型的預測是“正確的”。

說明:rank-1和rank-5準確性並不僅限於深度學習和影象分類,還可以使用在其它領域。

詳細程式碼位置:github