pytorch筆記:06)requires_grad和volatile
阿新 • • 發佈:2018-11-10
requires_grad
Variable變數的requires_grad的屬性預設為False,若一個節點requires_grad被設定為True,那麼所有依賴它的節點的requires_grad都為True。
x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)
y依賴於w,w的requires_grad=True,因此y的requires_grad=True (類似or操作)
volatile
volatile=True是Variable的另一個重要的標識,它能夠將所有依賴它的節點全部設為volatile=True,其優先順序比requires_grad=True高。因而volatile=True的節點不會求導,即使requires_grad=True,也不會進行反向傳播,對於不需要反向傳播的情景(inference,測試推斷),該引數可以實現一定速度的提升,並節省一半的視訊記憶體,因為其不需要儲存梯度
前方高能預警:如果你看完了前面volatile,請及時把它從你的腦海中擦除掉,因為
UserWarning: volatile was removed (Variable.volatile is always False)
該屬性已經在0.4版本中被移除了,並提示你可以使用with torch.no_grad()代替該功能
>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
https://pytorch.org/docs/master/autograd.html#locally-disable-grad