1. 程式人生 > >PyTorch—torchvision.models匯入預訓練模型與殘差網路講解

PyTorch—torchvision.models匯入預訓練模型與殘差網路講解

文章目錄


PyTorch框架中torchvision模組下有:torchvision.datasets、torchvision.models、torchvision.transforms這3個子包。
關於詳情請參考官網:
http://pytorch.org/docs/master/torchvision/index.html。

具體程式碼可以參考github: https://github.com/pytorch/vision/tree/master/torchvision。

torchvision.models

此模組下有常用的 alexnet、densenet、inception、resnet、squeezenet、vgg(關於網路詳情請檢視)等常用的網路結構,並且提供了預訓練模型,我們可以通過簡單呼叫來讀取網路結構和預訓練模型,同時使用fine tuning(微調)來使用。
關於 fine tuning 可以檢視

https://blog.csdn.net/hjxu2016/article/details/78424370

1. 模組呼叫
import torchvision

"""
    如果你需要用預訓練模型,設定pretrained=True
    如果你不需要用預訓練模型,設定pretrained=False,預設是False,你可以不寫
"""
model = torchvision.models.resnet50(pretrained=True) 
model = torchvision.models.resnet50() 

# 你也可以匯入densenet模型。且不需要是預訓練的模型
model =
torchvision.models.densenet169(pretrained=False)
2. 原始碼解析

以匯入resnet50為例,介紹具體匯入模型時候的原始碼。
執行 model = torchvision.models.resnet50(pretrained=True)的時候,是通過models包下的resnet.py指令碼進行的,原始碼如下:

首先是匯入必要的庫,其中model_zoo是和匯入預訓練模型相關的包,另外all變數定義了可以從外部import的函式名或類名。這也是前面為什麼可以用torchvision.models.resnet50()來呼叫的原因。
model_urls這個字典是預訓練模型的下載地址。

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

接下來就是resnet50這個函數了,引數pretrained預設是False。

  1. model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)是構建網路結構,Bottleneck是另外一個構建bottleneck的類,在ResNet網路結構的構建中有很多重複的子結構,這些子結構就是通過Bottleneck類來構建的,後面會介紹。
  2. 如果引數pretrained是True,那麼就會通過model_zoo.py中的load_url函式根據model_urls字典下載或匯入相應的預訓練模型。
  3. 通過呼叫model的load_state_dict方法用預訓練的模型引數來初始化你構建的網路結構,這個方法就是PyTorch中通用的用一個模型的引數初始化另一個模型的層的操作。load_state_dict方法還有一個重要的引數是strict,該引數預設是True,表示預訓練模型的層和你的網路結構層嚴格對應相等(比如層名和維度)。
def resnet50(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model

其他resnet18、resnet101等函式和resnet50基本類似。
差別主要是在:
1、構建網路結構的時候block的引數不一樣,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。
2、呼叫的block類不一樣,比如在resnet50、resnet101、resnet152中呼叫的是Bottleneck類,而在resnet18和resnet34中呼叫的是BasicBlock類,這兩個類的區別主要是在residual結果中卷積層的數量不同,這個是和網路結構相關的,後面會詳細介紹。
3、如果下載預訓練模型的話,model_urls字典的鍵不一樣,對應不同的預訓練模型。因此接下來分別看看如何構建網路結構和如何匯入預訓練模型。

# pretrained (bool): If True, returns a model pre-trained on ImageNet

def resnet18(pretrained=False, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model

def resnet101(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
3. ResNet類

繼承PyTorch中網路的基類:torch.nn.Module :

  • 構建ResNet網路是通過ResNet這個類進行的。
  • 其次主要的是重寫初始化__init__()forward()
    __init __()中主要是定義一些層的引數。
    forward()中主要是定義資料在層之間的流動順序,也就是層的連線順序。
    另外還可以在類中定義其他私有方法用來模組化一些操作,比如這裡的_make_layer()是用來構建ResNet網路中的4個blocks。
    _make_layer():
    第一個輸入block是Bottleneck或BasicBlock類,
    第二個輸入是該blocks的輸出channel,
    第三個輸入是每個blocks中包含多少個residual子結構,因此layers這個列表就是前面resnet50的[3, 4, 6, 3]。
    _make_layer()方法中比較重要的兩行程式碼是:
    1、layers.append(block(self.inplanes, planes, stride, downsample)),該部分是將每個blocks的第一個residual結構儲存在layers列表中。
    2、 for i in range(1, blocks): layers.append(block(self.inplanes, planes)),該部分是將每個blocks的剩下residual 結構儲存在layers列表中,這樣就完成了一個blocks的構造。
    這兩行程式碼中都是通過Bottleneck這個類來完成每個residual的構建,接下來介紹Bottleneck類。
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
4. Bottlenect類

從前面的ResNet類可以看出,在構造ResNet網路的時候,最重要的是Bottleneck這個類,因為ResNet是由residual結構組成的,而Bottleneck類就是完成residual結構的構建。同樣Bottlenect還是繼承了torch.nn.Module類,且重寫了__init__和forward方法。從forward方法可以看出,bottleneck 就是我們熟悉的3個主要的卷積層、BN層和啟用層,最後的out += residual就是element-wise add的操作。

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
5. BasicBlock類

BasicBlock類和Bottleneck類類似,BasicBlock類主要是用來構建ResNet18和ResNet34網路,因為這兩個網路的residual結構只包含兩個卷積層,沒有Bottleneck類中的bottleneck概念。因此在該類中,第一個卷積層採用的是kernel_size=3的卷積,如conv3x3函式所示。

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
6. 獲取預訓練模型

前面提到這一行程式碼:
if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])),主要就是通過model_zoo.py中的load_url函式根據model_urls字典匯入相應的預訓練模型,models_zoo.py指令碼的github地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py。
load_url函式原始碼如下。

  • 首先model_dir是下載模型儲存地址,如果沒有指定則儲存在專案的.torch目錄下,最好指定。cached_file是儲存模型的路徑加上模型名稱。
  • 接下來的 if not os.path.exists(cached_file)語句用來判斷是否指定目錄下已經存在要下載模型,如果已經存在,就直接呼叫torch.load介面匯入模型,如果不存在,則從網上下載。
  • 下載是通過_download_url_to_file(url, cached_file, hash_prefix, progress=progress)進行的,不再細講。重點在於模型匯入是通過torch.load()介面來進行的,不管你的模型是從網上下載的還是本地已有的。
def load_url(url, model_dir=None, map_location=None, progress=True):
    """
    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Example:
        >>> state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)

鳴謝
https://blog.csdn.net/u014380165/article/details/79119664