1. 程式人生 > >【小白學PyTorch】6 模型的構建訪問遍歷儲存(附程式碼)

【小白學PyTorch】6 模型的構建訪問遍歷儲存(附程式碼)

文章轉載自微信公眾號:機器學習煉丹術。歡迎大家關注,這是我的學習分享公眾號,100+原創乾貨。 文章目錄: [TOC] 本文是對一些函式的學習。函式主要包括下面四個方便: - 模型構建的函式:```add_module```,```add_module```,```add_module``` - 訪問子模組:```add_module```,```add_module```,```add_module```,```add_module``` - 網路遍歷: ```add_module```,```add_module``` - 模型的儲存與載入:```add_module```,```add_module```,```add_module``` ## 1 模型構建函式 ```torch.nn.Module```是所有網路的基類,在PyTorch實現模型的類中都要繼承這個類(這個在之前的課程中已經提到)。在構建Module中,Module是一個包含其他的Module的,類似於,你可以先定義一個小的網路模組,然後把這個小模組作為另外一個網路的元件。**因此網路結構是呈現樹狀結構**。 我們先簡單定義一個網路: ```python import torch.nn as nn import torch class MyNet(nn.Module): def __init__(self): super(MyNet,self).__init__() self.conv1 = nn.Conv2d(3,64,3) self.conv2 = nn.Conv2d(64,64,3) def forward(self,x): x = self.conv1(x) x = self.conv2(x) return x net = MyNet() print(net) ``` 輸出結果: ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/a4e913dcfa5344288740a0c317152ce5~tplv-k3u1fbpfcp-zoom-1.image) ```MyNet```中有兩個屬性```conv1```和```conv2```是兩個卷積層,在正向傳播```forward```的過程中,依次呼叫這兩個卷積層實現網路的功能。 ### 1.1 add_module 這種是最常見的定義網路的功能,在有些專案中,會看到這樣的方法```add_module```。我們用這個方法來重寫上面的網路: ```python class MyNet(nn.Module): def __init__(self): super(MyNet,self).__init__() self.add_module('conv1',nn.Conv2d(3,64,3)) self.add_module('conv2',nn.Conv2d(64,64,3)) def forward(self,x): x = self.conv1(x) x = self.conv2(x) return x ``` 其實```add_module(name,layer)```和```self.name=layer```實現了相同的功能,**個人感覺也許是因為add_module可以使用字串來定義變數名字,所以可以放在迴圈中?反正這個先了解熟悉熟悉**。 上面的兩種方法都是一層一層的新增layer,如果網路複雜的話,那就需要寫很多重複的程式碼了。因此接下來來講解一下網路模組的構建,```torch.nn.ModuleList```和```torch.nn.Sequential``` ### 1.2 ModuleList ```ModuleList```按照字面意思是用```list```的形式儲存網路層的。這樣就可以先將網路需要的layer構建好,儲存到一個list,然後通過```ModuleList```方法新增到網路中. ```python class MyNet(nn.Module): def __init__(self): super(MyNet,self).__init__() self.linears = nn.ModuleList( [nn.Linear(10,10) for i in range(5)] ) def forward(self,x): for l in self.linears: x = l(x) return x net = MyNet() print(net) ``` 輸出結果是: ![](https://p9-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/d24a5278ecb545089ad9149265b8fa50~tplv-k3u1fbpfcp-zoom-1.image) 這個ModuleList主要是用在讀取config檔案來構建網路模型中的,下面用VGG模型的構建為例子: ```python vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512, 'M'] def vgg(cfg, i, batch_norm=False): layers = [] in_channels = i for v in cfg: if v == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] elif v == 'C': layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: layers += [conv2d, nn.ReLU(inplace=True)] in_channels = v return layers class Model1(nn.Module): def __init__(self): super(Model1,self).__init__() self.vgg = nn.ModuleList(vgg(vgg_cfg,3)) def forward(self,x): for l in self.vgg: x = l(x) m1 = Model1() print(m1) ``` ![](https://p6-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/0c7863120c57409b880672c0a1e6c3c5~tplv-k3u1fbpfcp-zoom-1.image) 先讀取網路結構的配置檔案```vgg_cfg```然後根據這個檔案建立對應的Layer list,然後使用```ModuleList```新增到網路中,這樣可以快速建立不同的網路(**用上面為例子的話,可以通過修改配置檔案,然後快速修改網路結構** ) ### 1.3 Sequential 在一些自己做的小專案中,```Sequential```其實用的更為頻繁。 依然重寫最初最簡單的例子: ```python class MyNet(nn.Module): def __init__(self): super(MyNet,self).__init__() self.conv = nn.Sequential( nn.Conv2d(3,64,3), nn.Conv2d(64,64,3) ) def forward(self,x): x = self.conv(x) return x net = MyNet() print(net) ``` 執行結果: ![](https://p9-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/72b281e6505642e196abb79ce1bc98aa~tplv-k3u1fbpfcp-zoom-1.image) 觀察細緻的朋友可以發現這個問題,Seqential內的網路層是預設用數字進行標號的,而一開始我們使用```self.conv1```和```self.conv2```的時候,使用conv1和conv2作為標號的。 ![](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/8ba49319883a43b996bea0fde6dcdcb3~tplv-k3u1fbpfcp-zoom-1.image) 我們如何修改```Sequential```中網路層的名稱呢?這裡需要使用到```collections.OrderedDict```有序字典。```Sequential```是支援有序字典構建的。 ```python from collections import OrderedDict class MyNet(nn.Module): def __init__(self): super(MyNet,self).__init__() self.conv = nn.Sequential(OrderedDict([ ('conv1',nn.Conv2d(3,64,3)), ('conv2',nn.Conv2d(64,64,3)) ])) def forward(self,x): x = self.conv(x) return x net = MyNet() print(net) ``` 輸出結果: ![](https://p9-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/9b76cd12691e44e3bbab27aab47f1ff0~tplv-k3u1fbpfcp-zoom-1.image) ### 1.4 小總結 - 單獨增加一個網路層或者子模組,可以用```add_module```或者直接賦予屬性; - ```ModuleList```可以將一個Module的List增加到網路中,自由度較高。 - ```Sequential```按照順序產生一個Module模組。**這裡推薦習慣使用OrderedDict的方法進行構建。對網路層加上規範的名稱,這樣有助於後續查詢與遍歷** ## 2 遍歷模型結構 本章節使用下面的方法進行遍歷之前提到的```Module```。(**個人理解,Module是多個layer的合併,但是一個layer可以說成Module。** ) 先定義一個網路吧,隨便寫一個: ```python import torch.nn as nn import torch from collections import OrderedDict class MyNet(nn.Module): def __init__(self): super(MyNet,self).__init__() self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3) self.conv2 = nn.Conv2d(64,64,3) self.maxpool1 = nn.MaxPool2d(2,2) self.features = nn.Sequential(OrderedDict([ ('conv3', nn.Conv2d(64,128,3)), ('conv4', nn.Conv2d(128,128,3)), ('relu1', nn.ReLU()) ])) def forward(self,x): x = self.conv1(x) x = self.conv2(x) x = self.maxpool1(x) x = self.features(x) return x net = MyNet() print(net) ``` 輸出結果是: ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/9f113e563ef146a2b2e79fd7229e235e~tplv-k3u1fbpfcp-zoom-1.image) ### 2.1 modules() 在第四課中初始化模型各個層的引數的時候,用到了這個方法,現在我們再來理解一下: ```python for idx,m in enumerate(net.modules()): print(idx,"-",m) ``` 執行結果: ![](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/153f6b96774a4ffb9ab0d2d8a42f30fa~tplv-k3u1fbpfcp-zoom-1.image) 上面那個網路構建的時候用到了```Sequential```,所以網路中其實是嵌套了一個小的Module,這就是之前提到的樹狀結構,然後上面便利的時候也是樹狀結構的便利過程,可以看出來應該**是一個深度遍歷的過程。** - 首先第一個輸出的是最大的那個Module,也就是整個網路,```0-Model```整個網路模組; - ```1-2-3-4```是網路的四個子模組,```4-Sequential```中間仍然包含子模組 - ```5-6-7```是模組```4-Sequential```的子模組。 **【總結】** ```modules()```是遞迴的返回網路的各個module(深度遍歷),從最頂層直到最後的葉子的module。 ### 2.2 named_modules() ```named_modules()```和```module()```類似,只是同時返回name和module。 ```python for idx,(name,m) in enumerate(net.named_modules()): print(idx,"-",name) ``` 輸出結果: ![](https://p1-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/ed0dba636bca4bbf9c6b0ddd55d0bf35~tplv-k3u1fbpfcp-zoom-1.image) ### 2.3 parameters() ```python for p in net.parameters(): print(type(p.data),p.size()) ``` 執行結果: ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/e31182d66b4a4867847e1be0c21b0351~tplv-k3u1fbpfcp-zoom-1.image) 輸出的是四個卷積層的權重矩陣引數和偏置引數。值得一提的是,**對網路進行訓練時需要將parameters()作為優化器optimizer的引數。** ```python optimizer = torch.optim.SGD(net.parameters(), lr = 0.001, momentum=0.9) ``` 總之呢,這個```parameters()```是返回網路所有的引數,主要用在給optimizer優化器用的。而要對網路的某一層的引數做處理的時候,一般還是使用named_parameters()方便一些。 ```python for idx,(name,m) in enumerate(net.named_parameters()): print(idx,"-",name,m.size()) ``` 輸出結果: ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/4c74362e57f043ca84cd55893876d2dc~tplv-k3u1fbpfcp-zoom-1.image) **【小擴充套件】** 我個人有時會使用下面的方法來獲取引數: ```python for idx,(name,m) in enumerate(net.named_modules()): if isinstance(m,nn.Conv2d): print(m.weight.shape) print(m.bias.shape) ``` 先判斷是否是卷積層,然後獲取其引數,輸出結果: ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/2966ed628b024a4c912e77589510e54e~tplv-k3u1fbpfcp-zoom-1.image) ## 3 儲存與載入 PyTorch使用```torch.save```和```torch.load```方法來儲存和載入網路,而且網路結構和引數可以分開的儲存和載入。 ```python torch.save(model,'model.pth') # 儲存 model = torch.load("model.pth") # 載入 ``` pytorch中**網路結構和模型引數是可以分開儲存的**。上面的方法是兩者同時儲存到了.pth檔案中,當然,你也可以僅僅儲存網路的引數來減小儲存檔案的大小。**注意:如果你僅僅儲存模型引數,那麼在載入的時候,是需要通過執行程式碼來初始化模型的結構的。** ```python torch.save(model.state_dict(),"model.pth") # 儲存引數 model = MyNet() # 程式碼中建立網路結構 params = torch.load("model.pth") # 載入引數 model.load_state_dict(params) # 應用到網路結構中 ``` 至此,我們今天已經學習了不少的內容,大家對PyTorch的掌握更近一