1. 程式人生 > >focal loss論文筆記(附基於keras的多類別focal loss程式碼)

focal loss論文筆記(附基於keras的多類別focal loss程式碼)

一.focal loss論文

二.focal loss提出的目的

  • 解決one-stage目標檢測是場景下前景和背景極度不平衡的情況(1:1000)
  • 讓模型在訓練的時候更加關注hard examples(前景)。
  • 另外two-stage的檢測器是用一下兩個方法來解決類別不平衡問題的:
    • 提取候選框的過程實際上就消除了很多背景框,因為提取的候選框是大概率包括目標的
    • 在第二個階段訓練的時候,minibatch一般被認為的固定正負樣本的比例,大致是1:3

三.focal loss原理

1.CE(cross entropy) 交叉熵

mark
mark

2.balanced CE

mark

3.focal loss

mark
- 當一個樣本誤分類後,pt接近0,1-pt接近1,則loss無影響;當pt接近與1,則1-pt接近與0,loss的權重變的很小,則該樣本的loss對總的loss貢獻就小了。
- lamda是一個超引數,反應了權值係數的影響程度,在作者的實驗中lamda=2的結果是最好的。

4.focal loss with balanced weight

mark

5.RetinaNet模型框架部分

mark
- backbone部分採用基於resnet的FPN接面構,P3到P7一共5層的金字塔結構
- anchor部分,對於密集的目標場景增加了更多尺度的anchor
- 分類子網路部分,每個金字塔level的子網路引數是共享的,一個子網路包含4個conv層,卷積核大小為3*3;分類子網路和迴歸子網路的結構相同,但是引數是分開的,不像RPN裡面是共享的。
- 網路的初始化,在分類子網路的最後一層卷積部分,偏執初始化為-log((1-pi)/pi);其他的偏執初始化為0,權值初始化為高斯權值,delta取0.01

四.focal loss程式碼

  • 論文中理論敘述的場景是基於二分類,以下為基於多分類的focal loss,同時參考一篇部落格,加入了另外的因子能夠更好的防止過擬合