1. 程式人生 > >動態匯入模組,載入預訓練模型,nn.Sequential函式裡面必須是a Module subclass,不能是一個列表或者是其他的迭代器、生成器,雖然這裡麵包含了Module的子類

動態匯入模組,載入預訓練模型,nn.Sequential函式裡面必須是a Module subclass,不能是一個列表或者是其他的迭代器、生成器,雖然這裡麵包含了Module的子類

 

class RES(nn.Module):
    def __init__(self):
        super(RES, self).__init__()
        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(64)
        self.relu=nn.ReLU(inplace=True)
        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        self.conv2=nn.Conv2d(64,128,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn2=nn.BatchNorm2d(128)
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.maxpool(x)
        x=self.conv2(x)
        x=self.bn2(x)
        return x

model=RES()
glb = nn.Sequential(*list(model.children())[:4])

有兩點資料的說明:這個類繼承了Module一定要用super函式

nn.Sequential函式裡面的引數一定是Module的子類,而list:list is not a Module subclass。所以不能當做引數,當然model.children()也是一樣:Module.children is not a Module subclass。這裡的*就起了作用,將list或者children的內容迭代的一個一個的傳進去,效果如下:

 當然,我們還可以像最上面的那樣,選取裡面的幾個Module,例如[:4]也就是第0個到第3個.

 

動態匯入模組,使用importlib.import_module函式實際上是import了一個叫做resnet的檔案,下面的語句相當於 import xxx as resnet

當然這裡的xxx是該檔案的實際路徑

import importlib
resnet = importlib.import_module("torchvision.models.resnet")
resnet18=resnet.resnet18()
resnet34=resnet.resnet34()
resnet50=resnet.resnet50()
resnet101=resnet.resnet101()
resnet152=resnet.resnet152()

其他的模組有:

"""
alexnet檔案
"""
alexnet=importlib.import_module("torchvision.models.alexnet")
alexnet=alexnet.alexnet()
nn.Sequential(*alexnet.children())

"""
vgg檔案
"""
vgg=importlib.import_module("torchvision.models.vgg")
vgg16=vgg.vgg16() # vgg11=vgg.vgg11(),vgg19=vgg.vgg19(),vgg13=vgg.vgg13()以及他們的bn形式
# vgg16_bn=vgg.vgg16_bn(),vgg11_bn=vgg.vgg11_bn(),vgg19_bn=vgg.vgg19_bn(),vgg13_bn=vgg.vgg13_bn()
nn.Sequential(*vgg16.children())

"""
densenet檔案
"""
densenet=importlib.import_module("torchvision.models.densenet")
densenet121=densenet.densenet121() 
# densenet169=densenet.densenet169(),densenet201=densenet.densenet201(),densenet161=densenet.densenet161()
nn.Sequential(*densenet121.children())

"""
inception檔案
"""
inception=importlib.import_module("torchvision.models.inception")
inception_v3=inception.inception_v3()
nn.Sequential(*inception_v3.children())

"""
squeezenet檔案
"""
squeezenet=importlib.import_module("torchvision.models.squeezenet")
squeezenet1_0=inception.squeezenet1_0()
# squeezenet1_0=inception.squeezenet1_1()
nn.Sequential(*squeezenet1_0.children())

還有一種匯入方式,是比較常用的,推薦的:

import torchvision.models as models
models.squeezenet1_0()

"""
models後面直接接的是網路
models的__init__檔案如下
"""
from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
"""
可以看出來,匯入的是這5個檔案裡面的函式(類)
*代表想對應檔案的__all__,下面是各個檔案的該屬性以及訓練好的權重
"""
# alexnet
__all__ = ['AlexNet', 'alexnet']
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
# resnet
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
# vgg
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',]
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
# squeezenet
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}
# inception
__all__ = ['Inception3', 'inception_v3']
model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
# densenet
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}

所有的模型預設都是不載入預訓練模型引數的,怎麼載入預訓練模型引數呢?很簡單,就在括號裡面的pretrained設定成True,如果僅僅是需要該結構而不需要預訓練模型引數作為初始化,那麼pretrained=False。

resnet50 = models.resnet50(pretrained=True)

推薦!這裡有一篇比較綜合https://blog.csdn.net/weixin_41278720/article/details/80759933

其中可以補充一點就是將引數進行下載,相比載入模型來說更加的節省資源

    import torch.utils.model_zoo as model_zoo

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
                                           '/home/zzp/SSD_ping/my-root-path/My-core-python/PretrainedWeights')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)