[Deep Learning] Semantic Segmentation with SegNet (3)

Paper title: "SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation"

Paper: https://arxiv.org/abs/1511.00561

Code:

Python, PyTorch version: https://github.com/delta-onera/segnet_pytorch

Python, TensorFlow version: https://github.com/tkuanlun350/Tensorflow-SegNet

C++, Caffe version: https://github.com/alexgkendall/caffe-segnet

Paper demo: http://mi.eng.cam.ac.uk/projects/segnet/demo.php#demo

Contents

Paper overview

Network architecture

Max-pooling indices

Code walkthrough


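Paper overview

The novelty of SegNet lies in how the decoder upsamples its lower-resolution input feature maps. Specifically, the decoder uses the pooling indices computed in the max-pooling step of the corresponding encoder to perform non-linear upsampling. This eliminates the need to learn the upsampling.

SegNet is primarily motivated by scene-understanding applications. It is therefore designed to be efficient in both memory and computation time during inference. It also has far fewer trainable parameters than competing architectures and can be trained end-to-end with stochastic gradient descent.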
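
As a rough illustration of the memory argument, the sketch below compares what one encoder stage would have to keep for its decoder: the full float32 feature map, or only the position of the maximum in each 2x2 pooling window. The 2-bits-per-window figure follows the paper's observation that four positions fit in 2 bits. The 64 x 360 x 480 feature-map size and both helper names are assumptions chosen for illustration, not values from this post.

```python
def feature_map_bytes(channels, height, width, bytes_per_value=4):
    """Bytes needed to keep a full float32 feature map."""
    return channels * height * width * bytes_per_value


def pooling_index_bytes(channels, height, width, window=2, bits_per_window=2):
    """Bytes needed to keep one argmax position per pooling window."""
    windows = channels * (height // window) * (width // window)
    return windows * bits_per_window // 8


# Assumed example: a 64-channel, 360 x 480 feature map before pooling.
full = feature_map_bytes(64, 360, 480)
idx = pooling_index_bytes(64, 360, 480)
ratio = full // idx

print(full)   # 44236800
print(idx)    # 691200
print(ratio)  # 64
```

This is the idealized figure. PyTorch's max_pool2d returns its indices as int64 tensors, so the PyTorch code later in this post does not realize the full 2-bit saving.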

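Reusing the max-pooling indices in the decoder has several benefits:

  1. It improves the delineation of object boundaries
  2. It reduces the number of parameters to train
  3. It can be applied to any encoder-decoder architecture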

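Network architecture

SegNet has an encoder network and a corresponding decoder network, followed by a final pixel-wise classification layer. The encoder network consists of 13 convolutional layers, which correspond to the first 13 convolutional layers of the VGG16 network. The fully connected layers are discarded so that higher-resolution feature maps are retained at the deepest encoder output. This also greatly reduces the number of parameters in the SegNet encoder network. Each encoder layer has a corresponding decoder layer, so the decoder network also has 13 layers. The final decoder output is fed to a multi-class softmax classifier that produces class probabilities for each pixel independently.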
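
To see how much dropping the fully connected layers saves, the following sketch counts the weights and biases of the 13 convolutional layers and of VGG16's first two fully connected layers. It assumes a 3-channel input and the standard 224 x 224 VGG16 input size (which gives a 512 x 7 x 7 feature map before fc6). Batch-norm parameters are not counted, and the helper names are illustrative only.

```python
# (in_channels, out_channels) of the 13 VGG16 convolution layers, all 3x3.
VGG16_CONVS = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512),
]


def conv_params(c_in, c_out, k=3):
    """Weights plus biases of one k x k convolution."""
    return k * k * c_in * c_out + c_out


def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


encoder_params = sum(conv_params(i, o) for i, o in VGG16_CONVS)

# fc6 and fc7 of VGG16, assuming a 224 x 224 input (512 x 7 x 7 features).
fc_params = linear_params(512 * 7 * 7, 4096) + linear_params(4096, 4096)

print(encoder_params)              # 14714688
print(encoder_params + fc_params)  # 134260544
```

Under these assumptions the convolutional encoder holds about 14.7M parameters, against roughly 134M when fc6 and fc7 are kept.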

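Compared with other networks:

DeconvNet has a much larger parameterization, needs more computational resources, and is harder to train end-to-end, mainly because it uses fully connected layers (albeit in a convolutional manner). Unlike SegNet, U-Net (proposed for the medical imaging community) does not reuse pooling indices. Instead, it transfers the entire feature map to the corresponding decoder (at the cost of more memory) and concatenates it with the upsampled (via deconvolution) decoder feature map. U-Net has no conv5 and max-pool5 block as found in the VGG net architecture. SegNet, on the other hand, uses all of the pre-trained convolutional layer weights from VGG net as its pre-trained weights.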

[Figure: comparison of results across the different networks]

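Max-pooling indices

Suppose a, b, c, and d in the accompanying figure are values in a feature map. SegNet uses the max-pooling indices to upsample the feature map (no learning required) and then convolves the result with a bank of trainable decoder filters. FCN upsamples by learning to deconvolve the input feature map and then adds the corresponding encoder feature map to produce the decoder output. That encoder feature map is the output of the max-pooling layer (including sub-sampling) in the corresponding encoder. Note that FCN has no trainable decoder filters. Here, "decoder filters" means the convolutional layers in the decoder that keep the spatial size of their input unchanged, not transposed convolutions (deconvolutions).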
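
The index-based upsampling can be sketched in plain Python. The two helpers below are hypothetical and written only for illustration; they handle a single 2-D feature map with a 2x2 window and stride 2. Each index is stored as the flattened position (row * width + col) of the window maximum, which is the convention PyTorch uses per feature-map plane. The four pooled values play the role of a, b, c, and d.

```python
def max_pool_2x2(fmap):
    """2x2, stride-2 max pooling on a 2-D list; returns (pooled, indices)."""
    h, w = len(fmap), len(fmap[0])
    pooled, indices = [], []
    for r in range(0, h - 1, 2):
        prow, irow = [], []
        for c in range(0, w - 1, 2):
            # Collect (value, flattened position) for the four window cells.
            window = [(fmap[rr][cc], rr * w + cc)
                      for rr in (r, r + 1) for cc in (c, c + 1)]
            value, pos = max(window, key=lambda t: t[0])
            prow.append(value)
            irow.append(pos)
        pooled.append(prow)
        indices.append(irow)
    return pooled, indices


def max_unpool_2x2(pooled, indices, out_h, out_w):
    """Place each pooled value back at its recorded position; zeros elsewhere."""
    out = [[0] * out_w for _ in range(out_h)]
    for prow, irow in zip(pooled, indices):
        for value, pos in zip(prow, irow):
            out[pos // out_w][pos % out_w] = value
    return out


fmap = [
    [1, 5, 2, 0],
    [3, 4, 8, 1],
    [7, 2, 0, 3],
    [1, 6, 9, 4],
]
pooled, indices = max_pool_2x2(fmap)
restored = max_unpool_2x2(pooled, indices, 4, 4)

print(pooled)    # [[5, 8], [7, 9]]
print(indices)   # [[1, 6], [8, 14]]
print(restored)  # [[0, 5, 0, 0], [0, 0, 8, 0], [7, 0, 0, 0], [0, 0, 9, 0]]
```

The unpooled map is sparse: only the recorded maximum positions are filled. This is why SegNet follows each unpooling step with trainable convolutions that densify the result.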

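Code walkthrough

Code: https://github.com/delta-onera/segnet_pytorch/blob/master/segnet.py

The code is based on PyTorch. Because the encoder is essentially VGG16 and the decoder mirrors it in reverse, the code is easy to follow. The only point that needs special attention is how the max-pooling indices are implemented.

Two PyTorch functions already provide this functionality: max_pool2d and max_unpool2d.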

torch.nn.functional.max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

torch.nn.functional.max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None)

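Parameters:
- input – the input tensor (minibatch x in_channels x iH x iW)
- kernel_size – size of the pooling region; a single number or a tuple (kh x kw)
- stride – stride of the pooling operation; a single number or a tuple (sh x sw). Defaults to kernel_size
- padding – implicit padding added on both sides of the input; a single number or a tuple (padh x padw). For max pooling the padded values are negative infinity rather than zero. Default: 0
- dilation – spacing between the elements inside the pooling window. Default: 1
- ceil_mode – when True, uses ceil instead of floor to compute the output shape
- return_indices – when True, returns the indices of the maxima along with the output
- indices (max_unpool2d only) – the indices returned by max_pool2d
- output_size (max_unpool2d only) – the target output size, for cases where it cannot be inferred unambiguously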
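
How these parameters determine the output size along one spatial dimension can be sketched as follows. The formula is the one given in the PyTorch MaxPool2d documentation; the helper name pool_output_size is made up for this example. The extra ceil_mode check reflects my understanding that PyTorch drops a window that would start entirely inside the right or bottom padding, so treat that line as an assumption.

```python
import math


def pool_output_size(size, kernel_size, stride=None, padding=0,
                     dilation=1, ceil_mode=False):
    """Output length of max pooling along one spatial dimension."""
    stride = kernel_size if stride is None else stride  # default: stride = kernel
    span = size + 2 * padding - dilation * (kernel_size - 1) - 1
    rounding = math.ceil if ceil_mode else math.floor
    out = rounding(span / stride) + 1
    # Assumption: with ceil_mode, the last window must start inside the
    # input or the left padding, otherwise it is dropped.
    if ceil_mode and (out - 1) * stride >= size + padding:
        out -= 1
    return out


print(pool_output_size(480, kernel_size=2, stride=2))                 # 240
print(pool_output_size(45, kernel_size=2, stride=2))                  # 22
print(pool_output_size(45, kernel_size=2, stride=2, ceil_mode=True))  # 23
```

With the default floor rounding, an odd input size loses its last row or column during pooling. This matters later, when the decoder has to unpool back to the original size.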

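When return_indices is set to True, max_pool2d returns both the output and the indices, and those indices can then be passed to max_unpool2d as its indices argument.

For example:

        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)

        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)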

Finally, here is the complete code.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


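class SegNet(nn.Module):
    """SegNet: a VGG16-style encoder and a mirrored decoder linked by max-pooling indices.

    input_nbr: number of input channels; label_nbr: number of output classes.
    """
    def __init__(self, input_nbr, label_nbr):
        super(SegNet, self).__init__()

        batchNorm_momentum = 0.1

        # Encoder: the 13 convolutional layers of VGG16, each followed by batch norm
        self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)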
        self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        # Decoder: mirrors the encoder in reverse order (layers with the "d" suffix)
        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv31d = nn.Conv2d(256,  128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        # Last layer: one output channel per class, with no batch norm or ReLU
        self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)


    def forward(self, x):

        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True)


        # Stage 5d
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 4d
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        # Raw per-pixel class scores; no softmax is applied inside the network
        x11d = self.conv11d(x12d)

        return x11d

    def load_from_segnet(self, model_path):
        # Assumes the file at model_path holds a whole pickled model (not just a
        # state dict) whose parameter names match this class.
        th = torch.load(model_path).state_dict()  # load the weights
        self.load_state_dict(th)
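
One practical caveat with the code above: each max_unpool2d call is made without output_size, so it produces exactly twice its input size. That only matches the encoder when every pre-pooling size is even. The sketch below follows one spatial dimension through the five stages and reports where default unpooling would miss the size that was pooled. The function name is hypothetical, 360 and 480 are assumed example sizes, and each stage is checked on its own, as if the earlier stages had been fixed.

```python
def unpool_mismatches(size, stages=5):
    """Return (stage, expected, got) for each stage where a default 2x2
    unpooling would not restore the size seen before pooling."""
    pre_pool = []
    for _ in range(stages):
        pre_pool.append(size)
        size //= 2                      # max_pool2d(kernel_size=2, stride=2)
    mismatches = []
    for stage in range(stages, 0, -1):  # the decoder runs stage 5d down to 1d
        expected = pre_pool[stage - 1]
        got = 2 * (expected // 2)       # default max_unpool2d output size
        if got != expected:
            mismatches.append((stage, expected, got))
    return mismatches


print(unpool_mismatches(480))  # []
print(unpool_mismatches(360))  # [(4, 45, 44)]
```

Input sizes that are multiples of 32 avoid the problem entirely. For other sizes, one option is to pass the pre-pooling size explicitly, for example output_size=x43.size() in the stage 4d call; this is a suggestion, not something the original repository does.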