1. 程式人生 > >【小白學PyTorch】12 SENet詳解及PyTorch實現

【小白學PyTorch】12 SENet詳解及PyTorch實現

文章來自微信公眾號【機器學習煉丹術】。我是煉丹兄,有什麼問題都可以來找我交流,近期建立了微信交流群,也在朋友圈抽獎贈書十多本了。我的微信是cyx645016617,歡迎各位朋友。 參考目錄: @[toc] 上一節課講解了MobileNet的一個DSC深度可分離卷積的概念,希望大家可以在實際的任務中使用這種方法,現在再來介紹EfficientNet的另外一個基礎知識—,**Squeeze-and-Excitation Networks壓縮-啟用網路** ## 1 網路結構 ![](https://img-blog.csdnimg.cn/img_convert/b26f49ae2738723f25913723d584c67a.png) 可以看出來,左邊的圖是一個典型的Resnet的結構,**Resnet這個殘差結構特徵圖求和而不是通道拼接,這一點可以注意一下** 這個SENet結構式融合在殘差網路上的,我來分析一下上圖右邊的結構: - 輸出特徵圖假設shape是$W \times H \times C$的; - 一般的Resnet就是這個特徵圖經過殘差網路的基本組塊,得到了輸出特徵圖,然後輸入特徵圖和輸入特徵圖通過殘差結構連在一起(通過加和的方式連在一起); - SE模組就是輸出特徵圖先經過一個全域性池化層,shape從$W \times H \times C$變成了$1 \times 1 \times C$,**這個就變成了一個全連線層的輸入啦** - 壓縮Squeeze:先放到第一個全連線層裡面,輸入$C$個元素,輸出$\frac{C}{r}$,r是一個事先設定的引數; - 啟用Excitation:在接上一個全連線層,輸入是$\frac{C}{r}$個神經元,輸出是$C$個元素,實現啟用的過程; - 現在我們有了一個$C$個元素的經過了兩層全連線層的輸出,這個C個元素,剛好表示的是原來輸出特徵圖$W \times H \times C$中C個通道的一個權重值,所以我們讓C個通道上的畫素值分別乘上全連線的C個輸出,這個步驟在圖中稱為**Scale**。**而這個調整過特徵圖每一個通道權重的特徵圖是SE-Resnet的輸出特徵圖,之後再考慮殘差接連的步驟。** 在原文論文中還有另外一個結構圖,供大家參考: ![](https://img-blog.csdnimg.cn/img_convert/b3ce4902f67b9d99f428b095fbb00171.png) ## 2 引數量分析 每一個卷積層都增加了額外的兩個全連線層,不夠好在全連線層的引數非常小,所以直觀來看應該整體不會增加很多的計算量。 Resnet50的引數量為25M的大小,增加了SE模組,增加了2.5M的引數量,所以大概增加了10%左右,而且這2.5M的引數主要集中在final stage的se模組,因為在最後一個卷積模組中,特徵圖擁有最大的通道數,所以這個final stage的引數量佔據了增加的2.5M引數的96%。 這裡放一個幾個網路結構的對比: ![](https://img-blog.csdnimg.cn/img_convert/9115532d3e87742fcc03798f11ee87b4.png) ## 3 PyTorch實現與解析 先上完整版的程式碼,大家可以複製本地IDE跑一跑,如果程式碼有什麼問題可以聯絡我: ```python import torch import torch.nn as nn import torch.nn.functional as F class PreActBlock(nn.Module): def __init__(self, in_planes, planes, stride=1): super(PreActBlock, self).__init__() self.bn1 = nn.BatchNorm2d(in_planes) self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) if stride != 1 or in_planes != planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) ) # SE layers self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) def forward(self, x): out = F.relu(self.bn1(x)) shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x out = self.conv1(out) out = self.conv2(F.relu(self.bn2(out))) # Squeeze w = F.avg_pool2d(out, out.size(2)) w = F.relu(self.fc1(w)) w = F.sigmoid(self.fc2(w)) # Excitation out = out * w out += shortcut return out class SENet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super(SENet, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512, num_classes) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1]*(num_blocks-1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes return nn.Sequential(*layers) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) out = self.linear(out) return out def SENet18(): return SENet(PreActBlock, [2,2,2,2]) net = SENet18() y = net(torch.randn(1,3,32,32)) print(y.size()) print(net) ``` 輸出和註解我都整理了一下: ![](https://img-blog.csdnimg.cn/img_convert/9450e51e10ce8cbf62c96b295f9a88c7.png) ![](https://img-blog.csdnimg.cn/img_convert/f5937544516fe5095d3edc7ea25d9