1. 程式人生 > >PyTorch中使用預訓練的模型初始化網路的一部分引數(增減網路層,修改某層引數等) 固定引數

PyTorch中使用預訓練的模型初始化網路的一部分引數(增減網路層,修改某層引數等) 固定引數

在預訓練網路的基礎上,修改部分層得到自己的網路,通常我們需要解決的問題包括: 
1. 從預訓練的模型載入引數 
2. 對新網路兩部分設定不同的學習率,主要訓練自己新增的層 

一. 載入引數的方法: 
載入引數可以參考apaszke推薦的做法,即刪除與當前model不匹配的key。程式碼片段為:

model = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)


二. 不同層設定不同學習率的方法 
此部分主要參考PyTorch教程的Autograd machnics部分 
2.1 在PyTorch中,每個Variable資料含有兩個flag(requires_grad和volatile)用於指示是否計算此Variable的梯度。設定requires_grad = False,或者設定volatile=True,即可指示不計算此Variable的梯度:

for param in model.parameters():
    param.requires_grad = False


注意,在模型測試時,對input_data設定volatile=True,可以節省測試時的視訊記憶體 
2.2 PyTorch的Module.modules()和Module.children() 

參考PyTorch document和discuss 
在PyTorch中,所有的neural network module都是class torch.nn.Module的子類,在Modules中可以包含其它的Modules,以一種樹狀結構進行巢狀。當需要返回神經網路中的各個模組時,Module.modules()方法返回網路中所有模組的一個iterator,而Module.children()方法返回所有直接子模組的一個iterator。具體而言:

list(nn.Sequential(nn.Linear(10, 20), nn.ReLU()).modules())
Out[9]:
[Sequential (
   (0): Linear (10 -> 20)
   (1): ReLU ()
 ), Linear (10 -> 20), ReLU ()]

In [10]: list(nn.Sequential(nn.Linear(10, 20), nn.ReLU()).children())
Out[10]: [Linear (10 -> 20), ReLU ()]


2.3 選擇特定的層進行finetune 
先使用Module.children()方法檢視網路的直接子模組,將不需要調整的模組中的引數設定為param.requires_grad = False,同時用一個list收集需要調整的模組中的引數。具體程式碼為:

count = 0
    para_optim = []
    for k in model.children():
        count += 1
        # 6 should be changed properly
        if count > 6:
            for param in k.parameters():
                para_optim.append(param)
        else:
            for param in k.parameters():
                param.requires_grad = False
optimizer = optim.RMSprop(para_optim, lr)



到此我們實現了PyTorch中使用預訓練的模型初始化網路的一部分引數,參考程式碼見我的GitHub
--------------------- 
作者:樂兮山南水北 
來源:CSDN 
原文:https://blog.csdn.net/u012494820/article/details/79068625 
版權宣告:本文為博主原創文章,轉載請附上博文連結!

有的時候我們需要對預訓練的模型增減一些網路層或著修改某些層的引數等

一、pytorch中的pre-train模型
卷積神經網路的訓練是耗時的,很多場合不可能每次都從隨機初始化引數開始訓練網路。
pytorch中自帶幾種常用的深度學習網路預訓練模型,如VGG、ResNet等。往往為了加快學習的進度,在訓練的初期我們直接載入pre-train模型中預先訓練好的引數,model的載入如下所示:

import torchvision.models as models
 
#resnet
model = models.ResNet(pretrained=True)
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)
 
#vgg
model = models.VGG(pretrained=True)
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)
model = models.vgg16_bn(pretrained=True)


二、預訓練模型的修改
1.引數修改
對於簡單的引數修改,這裡以resnet預訓練模型舉例,resnet原始碼在Github點選開啟連結。
resnet網路最後一層分類層fc是對1000種類型進行劃分,對於自己的資料集,如果只有9類,修改的程式碼如下:

# coding=UTF-8
import torchvision.models as models
 
#呼叫模型
model = models.resnet50(pretrained=True)
#提取fc層中固定的引數
fc_features = model.fc.in_features
#修改類別為9
model.fc = nn.Linear(fc_features, 9)

2.增減卷積層
前一種方法只適用於簡單的引數修改,有的時候我們往往要修改網路中的層次結構,這時只能用引數覆蓋的方法,即自己先定義一個類似的網路,再將預訓練中的引數提取到自己的網路中來。這裡以resnet預訓練模型舉例。

# coding=UTF-8
import torchvision.models as models
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
 
class CNN(nn.Module):
 
    def __init__(self, block, layers, num_classes=9):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        #新增一個反捲積層
        self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1, output_padding=0, groups=1, bias=False, dilation=1)
        #新增一個最大池化層
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        #去掉原來的fc層,新增一個fclass層
        self.fclass = nn.Linear(2048, num_classes)
 
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
 
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
 
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
 
        return nn.Sequential(*layers)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
 
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
 
        x = self.avgpool(x)
        #新加層的forward
        x = x.view(x.size(0), -1)
        x = self.convtranspose1(x)
        x = self.maxpool2(x)
        x = x.view(x.size(0), -1)
        x = self.fclass(x)
 
        return x
 
#載入model
resnet50 = models.resnet50(pretrained=True)
cnn = CNN(Bottleneck, [3, 4, 6, 3])
#讀取引數
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()
# 將pretrained_dict裡不屬於model_dict的鍵剔除掉
pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現有的model_dict
model_dict.update(pretrained_dict)
# 載入我們真正需要的state_dict
cnn.load_state_dict(model_dict)
# print(resnet50)
print(cnn)


--------------------- 
作者:whut_ldz 
來源:CSDN 
原文:https://blog.csdn.net/whut_ldz/article/details/78845947 
版權宣告:本文為博主原創文章,轉載請附上博文連結!