1. 程式人生 > >【小白學PyTorch】4 構建模型三要素與權重初始化

【小白學PyTorch】4 構建模型三要素與權重初始化

文章目錄: [TOC] ## 1 模型三要素 三要素其實很簡單 1. 必須要繼承nn.Module這個類,要讓PyTorch知道這個類是一個Module 2. 在\_\_init\_\_(self)中設定好需要的元件,比如conv,pooling,Linear,BatchNorm等等 3. 最後在forward(self,x)中用定義好的元件進行組裝,就像搭積木,把網路結構搭建出來,這樣一個模型就定義好了 我們來看一個例子: 先看\_\_init\_\_(self)函式 ```python def __init__(self): super(Net,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool1 = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.pool2 = nn.MaxPool2d(2,2) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) ``` 第一行是初始化,往後定義了一系列元件。```nn.Conv2d```就是一般圖片處理的卷積模組,然後池化層,全連線層等等。 定義完這些定義forward函式 ```python def forward(self,x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x ``` x為模型的輸入,第一行表示x經過conv1,然後經過啟用函式relu,然後經過pool1操作 第三行表示對x進行reshape,為後面的全連線層做準備 至此,對一個模型的定義完畢,如何使用呢? 例如: ```python net = Net() outputs = net(inputs) ``` 其實```net(inputs)```,就是類似於使用了```net.forward(inputs)```這個函式。 ## 2 引數初始化 簡單地說就是**設定什麼層用什麼初始方法,初始化的方法會在torch.nn.init中** 話不多說,看一個案例: ```python # 定義權值初始化 def initialize_weights(self): for m in self.modules(): if isinstance(m,nn.Conv2d): torch.nn.init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif isinstance(m,nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m,nn.Linear): torch.nn.init.normal_(m.weight.data,0,0.01) # m.weight.data.normal_(0,0.01) m.bias.data.zero_() ``` 這段程式碼的基本流程就是,先從self.modules()中遍歷每一層,然後判斷更曾屬於什麼型別,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然後根據不同型別的層,設定不同的權值初始化方法,例如Xavier,kaiming,normal_等等。**kaiming也是MSRA初始化,是何愷明大佬在微軟亞洲研究院的時候,因此得名。** 上面程式碼中用到了```self.modules()```,這個是什麼東西呢? ```python # self.modules的原始碼 def modules(self): for name,module in self.named_modules(): yield module ``` 功能就是:能依次返回模型中的各層,```yield```是讓一個函式可以像迭代器一樣可以用for迴圈不斷從裡面遍歷(可能說的不太明確)。 ## 3 完整執行程式碼 我們用下面的例子來更深入的理解```self.modules()```,同時也把上面的內容都串起來(下面的程式碼塊可以執行): ```python import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset,DataLoader class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool1 = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.pool2 = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x def initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): torch.nn.init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): torch.nn.init.normal_(m.weight.data, 0, 0.01) # m.weight.data.normal_(0,0.01) m.bias.data.zero_() net = Net() net.initialize_weights() print(net.modules()) for m in net.modules(): print(m) ``` 執行結果: ```python # 這個是print(net.modules())