1. 程式人生 > >pytorch筆記:06)requires_grad和volatile

pytorch筆記:06)requires_grad和volatile

requires_grad
Variable變數的requires_grad的屬性預設為False,若一個節點requires_grad被設定為True,那麼所有依賴它的節點的requires_grad都為True。

x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)

y依賴於w,w的requires_grad=True,因此y的requires_grad=True (類似or操作)

volatile
volatile=True是Variable的另一個重要的標識,它能夠將所有依賴它的節點全部設為volatile=True,其優先順序比requires_grad=True高。因而volatile=True的節點不會求導,即使requires_grad=True,也不會進行反向傳播,對於不需要反向傳播的情景(inference,測試推斷),該引數可以實現一定速度的提升,並節省一半的視訊記憶體,因為其不需要儲存梯度

前方高能預警:如果你看完了前面volatile,請及時把它從你的腦海中擦除掉,因為

UserWarning: volatile was removed (Variable.volatile
is always False)

該屬性已經在0.4版本中被移除了,並提示你可以使用with torch.no_grad()代替該功能

>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> 
z.requires_grad False

https://pytorch.org/docs/master/autograd.html#locally-disable-grad