1. 程式人生 > >pytorch學習筆記之載入預訓練模型

pytorch學習筆記之載入預訓練模型

原文:https://blog.csdn.net/weixin_41278720/article/details/80759933 

pytorch自發布以來,由於其便捷性,贏得了越來越多人的喜愛。

Pytorch有很多方便易用的包,今天要談的是torchvision包,它包括3個子包,分別是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分別是預定義好的資料集(比如MNIST、CIFAR10等)、預定義好的經典網路結構(比如AlexNet、VGG、ResNet等)和預定義好的資料增強方法(比如Resize、ToTensor等)。這些方法可以直接呼叫,簡化我們建模的過程,也可以作為我們學習或構建新的模型的參考。

本文,我們講述的是models,且只談模型的載入。models這個包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的網路結構,並且提供了預訓練模型,可以通過簡單呼叫來讀取網路結構和預訓練模型。

模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models

官方文件:https://pytorch.org/docs/master/torchvision/models.html

我將載入的方法簡單總結為以下四種:

1.直接載入預訓練模型 import torchvision.models as models

resnet50 = models.resnet50(pretrained=True) 這樣就匯入了resnet50的預訓練模型了。

如果只需要網路結構,不需要用預訓練模型的引數來初始化,那麼就是:

model =torchvision.models.resnet50(pretrained=False) 或者把resnet複製到自己的目錄下,新建個model資料夾

可以參考下面的貓狗大戰入門演算法入門

https://github.com/JackwithWilshere/Kaggle-Dogs_vs_Cats_PyTorch

2.修改某一層  以resnet為例,預設的是ImageNet的1000類,比如我們要做二分類,分類貓和狗

resnet.fc = nn.Linear(2048, 2)  resnet 第一層卷積的卷積核是7,我們可能想改成5,那麼可以通過以下方法修改:

#未經試驗,修改需要有理論依據,計算featuremap維度使之匹配。 resnet.conv1 = nn.Conv2d(3, 64,kernel_size=5, stride=2, padding=3, bias=False) 3.載入部分預訓練模型 對於具體的任務,很難保證模型和公開的模型完全一樣,但是預訓練模型的引數確實有助於提高訓練的準確率,為了結合二者的優點,就需要我們載入部分預訓練模型。

#載入model,model是自己定義好的模型 resnet50 = models.resnet50(pretrained=True)  model =Net(...) 

#讀取引數  pretrained_dict =resnet50.state_dict()  model_dict = model.state_dict() 

#將pretrained_dict裡不屬於model_dict的鍵剔除掉  pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict} 

# 更新現有的model_dict  model_dict.update(pretrained_dict) 

# 載入我們真正需要的state_dict  model.load_state_dict(model_dict)   4. 載入自己的模型 其實這個是儲存和恢復模型,比如我們訓練好的模型儲存,然後載入用於測試。

方法一(推薦):

第一種方法也是官方推薦的方法,只儲存和恢復模型中的引數。

使用這種方法,我們需要自己匯入模型的結構資訊。

(1)儲存

torch.save(model.state_dict(), PATH)

#example torch.save(resnet50.state_dict(),'ckp/model.pth')     (2)恢復

model = ModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH))

#example resnet=resnet50(pretrained=True) resnet.load_state_dict(torch.load('ckp/model.pth')) 方法二:

使用這種方法,將會儲存模型的引數和結構資訊。

(1)儲存

torch.save (model, PATH) (2)恢復

model = torch.load(PATH) 參考資料:

1. https://zhuanlan.zhihu.com/p/25980324

2. http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/ ---------------------  作者:spectre7  來源:CSDN  原文:https://blog.csdn.net/weixin_41278720/article/details/80759933  版權宣告:本文為博主原創文章,轉載請附上博文連結!