1. 程式人生 > GoogLeNet網絡的PyTorch實現

GoogLeNet網絡的PyTorch實現

一個 ins UNC __main__ target dds rgs ogr ogl

1.文章原文地址

Going Deeper with Convolutions（Christian Szegedy 等，arXiv:1409.4842，http://arxiv.org/abs/1409.4842）

2.文章摘要

我們提出了一種代號為 Inception 的深度卷積神經網絡架構，它在 2014 年 ImageNet 大規模視覺識別挑戰賽（ILSVRC2014）的分類與檢測任務上創下了新的最佳成績。該架構的主要特點是提高了網絡內部計算資源的利用率：通過精心的設計，在保持計算預算不變的前提下增加了網絡的深度與寬度。為了優化效果，架構上的設計決策基於 Hebbian 原則和多尺度處理的直覺。我們提交到 ILSVRC2014 的一個具體實例稱為 GoogLeNet，它是一個 22 層深的網絡，其效果在分類與檢測任務中加以評估。

3.網絡結構

（圖片缺失：網絡結構示意圖）

（圖片缺失：網絡結構示意圖）

4. PyTorch實現

  1 import warnings
  2 from collections import namedtuple
  3 import torch
  4 import torch.nn as nn
  5 import torch.nn.functional as F
  6 from torch.utils.model_zoo import load_url as load_state_dict_from_url
  7 from torchsummary import summary
  8 
  9 __all__ = [GoogLeNet, googlenet]
 10
11 model_urls = { 12 # GoogLeNet ported from TensorFlow 13 googlenet: https://download.pytorch.org/models/googlenet-1378be20.pth, 14 } 15 16 _GoogLeNetOuputs = namedtuple(GoogLeNetOuputs, [logits, aux_logits2, aux_logits1]) 17 18 19 def googlenet(pretrained=False, progress=True, **kwargs):
20 r"""GoogLeNet (Inception v1) model architecture from 21 `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_. 22 Args: 23 pretrained (bool): If True, returns a model pre-trained on ImageNet 24 progress (bool): If True, displays a progress bar of the download to stderr 25 aux_logits (bool): If True, adds two auxiliary branches that can improve training. 26 Default: *False* when pretrained is True otherwise *True* 27 transform_input (bool): If True, preprocesses the input according to the method with which it 28 was trained on ImageNet. Default: *False* 29 """ 30 if pretrained: 31 if transform_input not in kwargs: 32 kwargs[transform_input] = True 33 if aux_logits not in kwargs: 34 kwargs[aux_logits] = False 35 if kwargs[aux_logits]: 36 warnings.warn(auxiliary heads in the pretrained googlenet model are NOT pretrained, 37 so make sure to train them) 38 original_aux_logits = kwargs[aux_logits] 39 kwargs[aux_logits] = True 40 kwargs[init_weights] = False 41 model = GoogLeNet(**kwargs) 42 state_dict = load_state_dict_from_url(model_urls[googlenet], 43 progress=progress) 44 model.load_state_dict(state_dict) 45 if not original_aux_logits: 46 model.aux_logits = False 47 del model.aux1, model.aux2 48 return model 49 50 return GoogLeNet(**kwargs) 51 52 53 class GoogLeNet(nn.Module): 54 55 def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): 56 super(GoogLeNet, self).__init__() 57 self.aux_logits = aux_logits 58 self.transform_input = transform_input 59 60 self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) 61 self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) #向上取整 62 self.conv2 = BasicConv2d(64, 64, kernel_size=1) 63 self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 64 self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 65 66 self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 67 self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 68 self.maxpool3 = 
nn.MaxPool2d(3, stride=2, ceil_mode=True) 69 70 self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 71 self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 72 self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 73 self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 74 self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 75 self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 76 77 self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 78 self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 79 80 if aux_logits: 81 self.aux1 = InceptionAux(512, num_classes) 82 self.aux2 = InceptionAux(528, num_classes) 83 84 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 self.dropout = nn.Dropout(0.2) 86 self.fc = nn.Linear(1024, num_classes) 87 88 if init_weights: 89 self._initialize_weights() 90 91 def _initialize_weights(self): 92 for m in self.modules(): 93 if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 94 import scipy.stats as stats 95 X = stats.truncnorm(-2, 2, scale=0.01) 96 values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 97 values = values.view(m.weight.size()) 98 with torch.no_grad(): 99 m.weight.copy_(values) 100 elif isinstance(m, nn.BatchNorm2d): 101 nn.init.constant_(m.weight, 1) 102 nn.init.constant_(m.bias, 0) 103 104 def forward(self, x): 105 if self.transform_input: 106 x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 107 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 108 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 109 x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 110 111 # N x 3 x 224 x 224 112 x = self.conv1(x) 113 # N x 64 x 112 x 112 114 x = self.maxpool1(x) 115 # N x 64 x 56 x 56 116 x = self.conv2(x) 117 # N x 64 x 56 x 56 118 x = self.conv3(x) 119 # N x 192 x 56 x 56 120 x = self.maxpool2(x) 121 122 # N x 192 x 28 x 28 123 x = self.inception3a(x) 124 # N x 256 x 28 x 28 125 x = 
self.inception3b(x) 126 # N x 480 x 28 x 28 127 x = self.maxpool3(x) 128 # N x 480 x 14 x 14 129 x = self.inception4a(x) 130 # N x 512 x 14 x 14 131 if self.training and self.aux_logits: 132 aux1 = self.aux1(x) 133 134 x = self.inception4b(x) 135 # N x 512 x 14 x 14 136 x = self.inception4c(x) 137 # N x 512 x 14 x 14 138 x = self.inception4d(x) 139 # N x 528 x 14 x 14 140 if self.training and self.aux_logits: 141 aux2 = self.aux2(x) 142 143 x = self.inception4e(x) 144 # N x 832 x 14 x 14 145 x = self.maxpool4(x) 146 # N x 832 x 7 x 7 147 x = self.inception5a(x) 148 # N x 832 x 7 x 7 149 x = self.inception5b(x) 150 # N x 1024 x 7 x 7 151 152 x = self.avgpool(x) 153 # N x 1024 x 1 x 1 154 x = x.view(x.size(0), -1) 155 # N x 1024 156 x = self.dropout(x) 157 x = self.fc(x) 158 # N x 1000 (num_classes) 159 if self.training and self.aux_logits: 160 return _GoogLeNetOuputs(x, aux2, aux1) 161 return x 162 163 164 class Inception(nn.Module): #Inception模塊 165 166 def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 167 super(Inception, self).__init__() 168 169 self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 170 171 self.branch2 = nn.Sequential( 172 BasicConv2d(in_channels, ch3x3red, kernel_size=1), 173 BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) 174 ) 175 176 self.branch3 = nn.Sequential( 177 BasicConv2d(in_channels, ch5x5red, kernel_size=1), 178 BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) 179 ) 180 181 self.branch4 = nn.Sequential( 182 nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 183 BasicConv2d(in_channels, pool_proj, kernel_size=1) 184 ) 185 186 def forward(self, x): 187 branch1 = self.branch1(x) 188 branch2 = self.branch2(x) 189 branch3 = self.branch3(x) 190 branch4 = self.branch4(x) 191 192 outputs = [branch1, branch2, branch3, branch4] 193 return torch.cat(outputs, 1) 194 195 196 class InceptionAux(nn.Module): #輔助分支 197 198 def __init__(self, in_channels, num_classes): 199 
super(InceptionAux, self).__init__() 200 self.conv = BasicConv2d(in_channels, 128, kernel_size=1) 201 202 self.fc1 = nn.Linear(2048, 1024) 203 self.fc2 = nn.Linear(1024, num_classes) 204 205 def forward(self, x): 206 # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 207 x = F.adaptive_avg_pool2d(x, (4, 4)) 208 # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 209 x = self.conv(x) 210 # N x 128 x 4 x 4 211 x = x.view(x.size(0), -1) 212 # N x 2048 213 x = F.relu(self.fc1(x), inplace=True) 214 # N x 1024 215 x = F.dropout(x, 0.7, training=self.training) 216 # N x 1024 217 x = self.fc2(x) 218 # N x num_classes 219 220 return x 221 222 223 class BasicConv2d(nn.Module): #Conv2d+BN+Relu 224 225 def __init__(self, in_channels, out_channels, **kwargs): 226 super(BasicConv2d, self).__init__() 227 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 228 self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 229 230 def forward(self, x): 231 x = self.conv(x) 232 x = self.bn(x) 233 return F.relu(x, inplace=True) 234 235 236 if __name__=="__main__": 237 model=googlenet() 238 print(model,(3,224,224))

參考

https://github.com/pytorch/vision/tree/master/torchvision/models

GoogLeNet網絡的PyTorch實現