1. 程式人生 > >pytorch學習筆記(三):自動求導

pytorch學習筆記(三):自動求導

Backward過程中排除子圖

pytorchBP過程是由一個函式決定的,loss.backward(), 可以看到backward()函式裡並沒有傳要求誰的梯度。那麼我們可以大膽猜測,在BP的過程中,pytorch是將所有影響lossVariable都求了一次梯度。但是有時候,我們並不想求所有Variable的梯度。那就要考慮如何在Backward過程中排除子圖(ie.排除沒必要的梯度計算)。
如何BP過程中排除子圖? Variable的兩個引數(requires_gradvolatile

requires_grad:

import torch
from torch.autograd import
Variable x = Variable(torch.randn(5, 5)) y = Variable(torch.randn(5, 5)) z = Variable(torch.randn(5, 5), requires_grad=True) a = x + y # x, y的 requires_grad的標記都為false, 所以輸出的變數requires_grad也為false a.requires_grad
False
b = a + z #a ,z 中,有一個 requires_grad 的標記為True,那麼輸出的變數的 requires_grad為True
b.requires_grad
True

變數的requires_grad標記的運算就相當於or
如果你想部分凍結你的網路(ie.不做梯度計算),那麼通過設定requires_grad標籤是非常容易實現的。
下面給出了利用requires_grad使用pretrained網路的一個例子,只fine tune了最後一層。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default model.fc = nn.Linear(512, 100) # Optimize only the classifier optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

volatile

j = Variable(torch.randn(5,5), volatile=True)
k = Variable(torch.randn(5,5))
m = Variable(torch.randn(5,5))
n = k+m # k,m變數的volatile標記都為False,輸出的Variable的volatile標記也為false
n.volatile
False
o = j+k #k,m變數的volatile標記有一個True,輸出的Variable的volatile為True
o.volatile
True

變數的volatile標記的運算也相當於or
注意:volatile=True相當於requires_grad=False。但是在純推斷模式的時候,只要是輸入volatile=True,那麼輸出Variable的volatile必為True。這就比使用requires_grad=False方便多了。

NOTE在使用volatile=True的時候,變數是不儲存 creator屬性的,這樣也減少了記憶體的使用。

為什麼要排除子圖

也許有人會問,梯度全部計算,不更新的話不就得了。
這樣就涉及了效率的問題了,計算很多沒用的梯度是浪費了很多資源的(時間,計算機記憶體)

參考資料