pytorch學習筆記(三):自動求導
阿新 • • 發佈:2019-02-14
Backward過程中排除子圖
pytorch
的BP
過程是由一個函式決定的,loss.backward()
, 可以看到backward()
函式裡並沒有傳要求誰的梯度。那麼我們可以大膽猜測,在BP
的過程中,pytorch
是將所有影響loss
的Variable
都求了一次梯度。但是有時候,我們並不想求所有Variable
的梯度。那就要考慮如何在Backward過程中排除子圖
(ie.排除沒必要的梯度計算)。
如何BP
過程中排除子圖? Variable
的兩個引數(requires_grad
和volatile
)
requires_grad
:
import torch
from torch.autograd import Variable
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
z = Variable(torch.randn(5, 5), requires_grad=True)
a = x + y # x, y的 requires_grad的標記都為false, 所以輸出的變數requires_grad也為false
a.requires_grad
False
b = a + z #a ,z 中,有一個 requires_grad 的標記為True,那麼輸出的變數的 requires_grad為True
b.requires_grad
True
變數的requires_grad
標記的運算就相當於or
。
如果你想部分凍結你的網路(ie.不做梯度計算),那麼通過設定requires_grad
標籤是非常容易實現的。
下面給出了利用requires_grad
使用pretrained
網路的一個例子,只fine tune
了最後一層。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
volatile
:
j = Variable(torch.randn(5,5), volatile=True)
k = Variable(torch.randn(5,5))
m = Variable(torch.randn(5,5))
n = k+m # k,m變數的volatile標記都為False,輸出的Variable的volatile標記也為false
n.volatile
False
o = j+k #k,m變數的volatile標記有一個True,輸出的Variable的volatile為True
o.volatile
True
變數的volatile
標記的運算也相當於or
。
注意:volatile=True
相當於requires_grad=False
。但是在純推斷模式的時候,只要是輸入volatile=True
,那麼輸出Variable的volatile
必為True
。這就比使用requires_grad=False
方便多了。
NOTE
:在使用volatile=True
的時候,變數是不儲存 creator
屬性的,這樣也減少了記憶體的使用。
為什麼要排除子圖
也許有人會問,梯度全部計算,不更新的話不就得了。
這樣就涉及了效率的問題了,計算很多沒用的梯度是浪費了很多資源的(時間,計算機記憶體)