第9篇 Fast AI深度學習課程——多目標識別與定位
一、一個模型同時實現單目標識別與定位
在上一節中,我們先構建了一個分類網路,用於圖片中最大目標的類別劃分;然後構建了一個用於輸出目標座標的網路。我們尚未將兩個網路聯絡起來。但事實上,兩個網路的架構十分相似(都是基於resnet34
)。那麼能否去除這種冗餘,使用一個網路同時實現目標分類與定位呢?本部分將按照:準備資料—構建網路—定義優化目標這一分解步驟,來展示針對應用場景進行建模的通用流程。
1. 準備資料
資料分為自變數和因變數兩部分,自變數自然就是圖片了。無論是分類還是定位,在構建網路時針對圖片所做的操作,均可通用,因此,這一部分不用考慮。而對於因變數,需要將目標類別和定位座標結合在一起。但兩者一個是連續型的,一個是離散型的;因此,在生成資料檔案CSV
dataset
拼接起來(dataset
實際上即為儲存資料的地方)。拼接方法是:在獲取資料時(即呼叫__getitem__()
函式時),同時返回角點座標和類別標籤。
md = ImageClassifierData.from_csv(PATH, JPEGS, BB_CSV, tfms=tfms,
continuous=True, val_idxs=val_idxs)
md2 = ImageClassifierData.from_csv(PATH, JPEGS, CSV,
tfms=tfms_from_model(f_model, sz))
class ConcatLblDataset(Dataset):
def __init__(self, ds, y2): self.ds,self.y2 = ds,y2
def __len__(self): return len(self.ds)
def __getitem__(self, i):
x,y = self.ds[i]
return (x, (y,self.y2[i]))
md. trn_dl.dataset = ConcatLblDataset(md.trn_ds, md2.trn_y)
md.val_dl.dataset = ConcatLblDataset(md.val_ds, md2.val_y)
事實上,ConcatLblDataset()
的第二個引數為一個可迭代物件,且和md
的資料的檔案索引相對應即可。
後面可考慮繼承DataSet類。
2. 網路架構
在分類網路與迴歸網路的共有部分的基礎上,再新增附加層以輸出分類和定位所需的數值:共需要4+c
個輸出,其中c
為類別數目。
head_reg4 = nn.Sequential(
Flatten(),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(25088,256),
nn.ReLU(),
nn.BatchNorm1d(256),
nn.Dropout(0.5),
nn.Linear(256,4+len(cats)),
)
models = ConvnetBuilder(f_model, 0, 0, 0, custom_head=head_reg4)
其中ConvenetBuilder
中取值為0的3個引數分別表示:全連線層的節點數目、是否為多分類、是否為迴歸問題,但在設定custom_head
後不起作用。
3. 損失函式
將定位網路的L1
範數誤差和分類網路的交叉熵加權求和,即得到所需的損失函式。其中要點如下:
- 損失函式為接受
input
和target
引數、返回一個數值的函式。其中input
即每個資料塊經過網路的前向傳播後所得的結果,target
即為每個資料塊的y
值(角點值,類別標籤)。 - 損失函式中的資料大部分為
torch.Variable
型別,以用於梯度計算。 - 通過設定學習器的
crit
域來設定損失函式,其接受一個函式。通過設定學習器的metrics
來顯示訓練過程中的指標,其接受一個函式列表。
後續訓練過程就沒啥新鮮東西了。
二、多目標的識別與定位
我們已經得到了能夠同時進行目標分類和定位的網路,考慮將之擴充套件為多目標分類與定位。思路是輸出固定目標數(課程中設定的是16
)的資訊:16x(4+c)
。有兩種方式:
- 修改單目標網路的輸出層,使之直接輸出
16x(4+c)
的數值。這一方法最初由YOLO
(You Only Look Once
)網路使用。 - 在
resnet34
後接一個跨立度為2
的卷積層,使其輸出為4x4x(4+c)
(resnet34
的最後一層輸出為7x7x512
)。這一方法最初由SSD
(Single Shot Detector
)使用。
1. 資料準備
將目標的座標資料整理為多類別分類網路所需的CSV
檔案格式。
md = ImageClassifierData.from_csv(PATH, JPEGS, MBB_CSV, tfms=tfms, bs=bs, continuous=True, num_workers=4)
繼而將md.trn_dl.dataset
以及md.val_dl.dataset
與類別資料進行拼接。需要注意的是:不同檔案中所含目標個數不同,md
採取的策略是按同批次的影象中目標個數的最大值進行補齊(這意味著不同批次的向量長度會有所變化。pascal VOC
資料中,007953.jpg
包含19
輛摩托車,是目標數目最大的圖片)。
2. 構建網路
按照SSD
方法,構建附加層。由於需要預測的資訊略多,可新增微網路以增強模型的描述功能。最終網路輸出兩組預測值:一組的尺寸為batch_sizex16x(1+c)
,用於目標類別的判定;一組尺寸為batch_sizex16x4
,用於目標定位。
3. 損失函式
考慮卷積結果的接觸域。(卷積結果中的一個元素實際上是由原影象中的部分元素的值決定的,這些原影象中的元素的分佈區域即為卷積結果中對應元素的接觸域。)所得輸出為4x4x(4+c+1)
,即可認為將原影象分為了4x4
部分,網路輸出結果中的每條特徵(4+c+1
維的向量),是對原影象中的某一塊的描述。
首先考慮分類。如何確定原影象中的某一塊屬於哪一類呢?定義影象的某塊與目標的重疊率為重疊區域面積與二者面積之和的比。
在知曉影象中的各個小塊與目標的關係後,就可按照單目標分類與定位的損失函式,對每一小塊計算損失,然後求平均。這裡需要注意的有如下幾點:
-
每一條特徵生成的定位框的座標,限定在其所對應的格點的附近範圍內。這樣就需要對影象網格化進行多樣化處理,以提供更強的描述能力。之所以採用這種方法的考慮如下:由於不確定格點對應的影象網格與目標的關聯程度,若使用某格點在全圖範圍內預測整個目標的定位框,可能需要引入重疊框的加權問題。
-
在求分類問題的損失函式時,採用的是二值熵函式,同時去除了背景項。考慮如下:各個網格構成了一個小的樣本,但這裡有個問題就是,這些樣本中大部分可能都是背景,也就是說這是個不均衡的樣本。去除背景類後,使用二值熵去判定各個網格是不是某類,更合適。
三、多目標的識別與定位的優化
1. 提供多樣化的網格
- 繼續進行跨立度為
2
的卷積,提供不同尺寸的影象網格; - 對網格進行縮放;
- 對網格進行拉伸。
2. 改進損失函式
如前所述,由於單一圖片的樣本不均衡性,導致在不確定某一小塊究竟是啥時,將之判定為背景總是最安全的。這會導致目標區域在圖片中較小時,網路認為圖中無目標。如下圖中的中間兩幅圖所示。一個解決方法是使用Focal Loss
:
其中 為二值熵函式。(具體為啥能改進,後面搞懂了再說吧~~)
圖 4. 被網路忽視目標的示例圖片3. 去除重疊窗
若兩個框框分屬同類,又有很高的重疊度,則將兩個框融合。
附註
1. 查詢Dataloader
返回資料的檔名
生成模型所需資料md
後,可通過next(iter(md.val_dl))
獲取一組資料。其返回值為影象陣列和影象標籤。怎麼找到這些資料對應的檔名呢?
先去看md
,其是ImageClassfierData
生成的,找到其定義處(在fastai/dataset.py
檔案中),發現其繼承自ImageData
,爺爺是ModelData
,但找完了它們的變數,沒有發現和影象檔名相關的。那就繼續找val_dl
,其是DataLoader
類,並發現val_ds
是由val_dl.dataset
返回的。而獲取一組資料時,呼叫的是DataLoader
的__iter__()
方法,該方法中顯示了從資料集生成Batch
時,呼叫了self.get_batch()
方法,該方法使用了抽樣器self.batch_sampler
。而所抽取的樣本都儲存在DataLoader.dataset
變數中。使用val_dl.dataset.__dict__
檢視其變數,發現其有fnames
欄位,儲存的應該是檔名。然後檢視val_dl.batch_sampler
,發現其是一個迭代器,由其產生索引,則可從fnames
中獲取檔名。
2. 近年來目標識別的發展歷程
圖 5. 目標識別方法演變脈絡一些有用的連結
- 課程wiki : 本節課程的一些相關資源,包括課程筆記、課上提到的部落格地址等。
- Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks :
Faster R-CNN
論文。 - Scalable Object Detection using Deep Neural Networks。
- You Only Look Once: Unified, Real-Time Object Detection:
YOLO
論文。 - SSD: Single Shot MultiBox Detector:
SSD
論文。 - Focal Loss for Dense Object Detection (RetinaNet):
Focal
論文。