1. 程式人生 > >從頭學pytorch(二十一):全連線網路dense net

從頭學pytorch(二十一):全連線網路dense net

DenseNet

論文傳送門,這篇論文是CVPR 2017的最佳論文.

resnet一文裡說了,resnet是具有里程碑意義的.densenet就是受resnet的啟發提出的模型.

resnet中是把不同層的feature map相應元素的值直接相加.而densenet是將channel維上的feature map直接concat在一起,從而實現了feature的複用.如下所示:


注意,是連線dense block內輸出層前面所有層的輸出,不是隻有輸出層的前一層

網路結構


首先實現DenseBlock

先解釋幾個名詞

  • bottleneck layer  
    即上圖中紅圈的1x1卷積核.主要目的是對輸入在channel維度做降維.減少運算量.


    卷積核的數量為4k,k為該layer輸出的feature map的數量(也就是3x3卷積核的數量)

  • growth rate
    即上圖中黑圈處3x3卷積核的數量.假設3x3卷積核的數量為k,則每個這種3x3卷積後,都得到一個channel=k的輸出.假如一個denseblock有m組這種結構,輸入的channel為n的話,則做完一次連線操作後得到的輸出的channel為n + k + k +...+k = n+m*k.所以又叫做growth rate.

  • conv  
    論文裡的conv指的是BN-ReLU-Conv

實現DenseBlock

DenseLayer

class DenseLayer(nn.Module):
    def __init__(self,in_channels,bottleneck_size,growth_rate):
        super(DenseLayer,self).__init__()
        count_of_1x1 = bottleneck_size
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1x1 = nn.Conv2d(in_channels,count_of_1x1,kernel_size=1)

        self.bn2 = nn.BatchNorm2d(count_of_1x1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3x3 = nn.Conv2d(count_of_1x1,growth_rate,kernel_size=3,padding=1)
        
    def forward(self,*prev_features):
        # for f in prev_features:
        #     print(f.shape)

        input = torch.cat(prev_features,dim=1)
        # print(input.device,input.shape)
        # for param in self.bn1.parameters():
        #     print(param.device)
        # print(list())
        bottleneck_output = self.conv1x1(self.relu1(self.bn1(input)))
        out = self.conv3x3(self.relu2(self.bn2(bottleneck_output)))
        
        return out

首先是1x1卷積,然後是3x3卷積.3x3卷積核的數量即growth_rate,bottleneck_size即1x1卷積核數量.論文裡是bottleneck_size=4xgrowth_rate的關係. 注意forward函式的實現

    def forward(self,*prev_features):
        # for f in prev_features:
        #     print(f.shape)

        input = torch.cat(prev_features,dim=1)
        # print(input.device,input.shape)
        # for param in self.bn1.parameters():
        #     print(param.device)
        # print(list())
        bottleneck_output = self.conv1x1(self.relu1(self.bn1(input)))
        out = self.conv3x3(self.relu2(self.bn2(bottleneck_output)))
        
        return out

我們傳進來的是一個元祖,其含義是[block的輸入,layer1輸出,layer2輸出,...].前面說過了,一個dense block內的每一個layer的輸入是前面所有layer的輸出和該block的輸入在channel維度上的連線.這樣就使得不同layer的feature map得到了充分的利用.

tips:
函式引數帶*表示可以傳入任意多的引數,這些引數被組織成元祖的形式,比如

## var-positional parameter
## 定義的時候,我們需要新增單個星號作為字首
def func(arg1, arg2, *args):
    print arg1, arg2, args
 
## 呼叫的時候,前面兩個必須在前面
## 前兩個引數是位置或關鍵字引數的形式
## 所以你可以使用這種引數的任一合法的傳遞方法
func("hello", "Tuple, values is:", 2, 3, 3, 4)
 
## Output:
## hello Tuple, values is: (2, 3, 3, 4)
## 多餘的引數將自動被放入元組中提供給函式使用
 
## 如果你需要傳遞元組給函式
## 你需要在傳遞的過程中新增*號
## 請看下面例子中的輸出差異:
 
func("hello", "Tuple, values is:", (2, 3, 3, 4))
 
## Output:
## hello Tuple, values is: ((2, 3, 3, 4),)
 
func("hello", "Tuple, values is:", *(2, 3, 3, 4))
 
## Output:
## hello Tuple, values is: (2, 3, 3, 4)

DenseBlock

class DenseBlock(nn.Module):
    def __init__(self,in_channels,layer_counts,growth_rate):
        super(DenseBlock,self).__init__()
        self.layer_counts = layer_counts
        self.layers = []
        for i in range(layer_counts):
            curr_input_channel = in_channels + i*growth_rate
            bottleneck_size = 4*growth_rate #論文裡設定的1x1卷積核是3x3卷積核的4倍.
            layer = DenseLayer(curr_input_channel,bottleneck_size,growth_rate).cuda()       
            self.layers.append(layer)

    def forward(self,init_features):
        features = [init_features]
        for layer in self.layers:
            layer_out = layer(*features) #注意引數是*features不是features
            features.append(layer_out)

        return torch.cat(features, 1)

一個Dense Block由多個Layer組成.這裡注意forward的實現,init_features即該block的輸入,然後每個layer都會得到一個輸出.第n個layer的輸入由輸入和前n-1個layer的輸出在channel維度上連線組成.

最後,該block的輸出為各個layer的輸出為輸入以及各個layer的輸出在channel維度上連線而成.

TransitionLayer

很顯然,dense block的計算方式會使得channel維度過大,所以每一個dense block之後要通過1x1卷積在channel維度降維.

class TransitionLayer(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(TransitionLayer, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

Dense Net

dense net的基本元件我們已經實現了.下面就可以實現dense net了.

class DenseNet(nn.Module):
    def __init__(self,in_channels,num_classes,block_config):
        super(DenseNet,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels,64,kernel_size=7,stride=2,padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.dense_block_layers = nn.Sequential()

        block_in_channels = in_channels
        growth_rate = 32
        for i,layers_counts in enumerate(block_config):
            block = DenseBlock(in_channels=block_in_channels,layer_counts=layers_counts,growth_rate=growth_rate)
            self.dense_block_layers.add_module('block%d' % (i+1),block)

            block_out_channels = block_in_channels + layers_counts*growth_rate
            transition = TransitionLayer(block_out_channels,block_out_channels//2)
            if i != len(block_config): #最後一個dense block後沒有transition layer
                self.dense_block_layers.add_module('transition%d' % (i+1),transition)

            block_in_channels = block_out_channels // 2 #更新下一個dense block的in_channels
        
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1,1))

        self.fc = nn.Linear(block_in_channels,num_classes)


    def forward(self,x):
        out = self.conv1(x)
        out = self.pool1(x)
        for layer in self.dense_block_layers:
            out = layer(out) 
            # print(out.shape)
        out = self.avg_pool(out)
        out = torch.flatten(out,start_dim=1) #相當於out = out.view((x.shape[0],-1))
        out = self.fc(out)

        return out

首先和resnet一樣,首先是7x7卷積接3x3,stride=2的最大池化,然後就是不斷地dense block + tansition.得到feature map以後用全域性平均池化得到n個feature.然後給全連線層做分類使用.

可以用

X=torch.randn(1,3,224,224).cuda()
block_config = [6,12,24,16]
net = DenseNet(3,10,block_config)
net = net.cuda()
out = net(X)
print(out.shape)

測試一下,輸出如下,可以看出feature map的變化情況.最終得到508x7x7的feature map.全域性平均池化後,得到508個特徵,通過線性迴歸得到10個類別.

torch.Size([1, 195, 112, 112])
torch.Size([1, 97, 56, 56])
torch.Size([1, 481, 56, 56])
torch.Size([1, 240, 28, 28])
torch.Size([1, 1008, 28, 28])
torch.Size([1, 504, 14, 14])
torch.Size([1, 1016, 14, 14])
torch.Size([1, 508, 7, 7])
torch.Size([1, 10])

總結:
核心就是dense block內每一個layer都複用了之前的layer得到的feature map,因為底層細節的feature被複用,所以使得模型的特徵提取能力更強. 當然壞處就是計算量大,視訊記憶體消耗