1. 程式人生 > >[PyTorch 學習筆記] 5.2 Hook 函式與 CAM 演算法

[PyTorch 學習筆記] 5.2 Hook 函式與 CAM 演算法

> 本章程式碼: > > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_fmap_vis.py) > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_methods.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/hook_methods.py) > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/weight_fmap_visualization.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson5/weight_fmap_visualization.py) 這篇文章主要介紹瞭如何使用 Hook 函式提取網路中的特徵圖進行視覺化,和 CAM(class activation map, 類啟用圖) # Hook 函式概念 Hook 函式是在不改變主體的情況下,實現額外功能。由於 PyTorch 是基於動態圖實現的,因此在一次迭代運算結束後,一些中間變數如非葉子節點的梯度和特徵圖,會被釋放掉。在這種情況下想要提取和記錄這些中間變數,就需要使用 Hook 函式。 PyTorch 提供了 4 種 Hook 函式。 ## torch.Tensor.register_hook(hook) 功能:註冊一個反向傳播 hook 函式,僅輸入一個引數,為張量的梯度。 `hook`函式: ``` hook(grad) ``` 引數: - grad:張量的梯度 程式碼如下: ``` w = torch.tensor([1.], requires_grad=True) x = torch.tensor([2.], requires_grad=True) a = torch.add(w, x) b = torch.add(w, 1) y = torch.mul(a, b) # 儲存梯度的 list a_grad = list() # 定義 hook 函式,把梯度新增到 list 中 def grad_hook(grad): a_grad.append(grad) # 一個張量註冊 hook 函式 handle = a.register_hook(grad_hook) y.backward() # 檢視梯度 print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad) # 檢視在 hook 函式裡 list 記錄的梯度 print("a_grad[0]: ", a_grad[0]) handle.remove() ``` 結果如下: ``` gradient: tensor([5.]) tensor([2.]) None None None a_grad[0]: tensor([2.]) ``` 在反向傳播結束後,非葉子節點張量的梯度被清空了。而通過`hook`函式記錄的梯度仍然可以檢視。 `hook`函式裡面可以修改梯度的值,無需返回也可以作為新的梯度賦值給原來的梯度。程式碼如下: ``` w = torch.tensor([1.], requires_grad=True) x = torch.tensor([2.], requires_grad=True) a = torch.add(w, x) b = torch.add(w, 1) y = torch.mul(a, b) a_grad = list() def grad_hook(grad): grad *= 2 return grad*3 handle = w.register_hook(grad_hook) y.backward() # 檢視梯度 print("w.grad: ", w.grad) handle.remove() ``` 結果是: ``` w.grad: tensor([30.]) ``` ## torch.nn.Module.register_forward_hook(hook) 功能:註冊 module 的前向傳播`hook`函式,可用於獲取中間的 feature map。 `hook`函式: ``` hook(module, input, output) ``` 引數: - module:當前網路層 - input:當前網路層輸入資料 - output:當前網路層輸出資料 下面程式碼執行的功能是 $3 \times 3$ 的卷積和 $2 \times 2$ 的池化。我們使用`register_forward_hook()`記錄中間卷積層輸入和輸出的 feature map。
``` class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x def forward_hook(module, data_input, data_output): fmap_block.append(data_output) input_block.append(data_input) # 初始化網路 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 註冊hook fmap_block = list() input_block = list() net.conv1.register_forward_hook(forward_hook) # inference fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W output = net(fake_img) # 觀察 print("output shape: {}\noutput value: {}\n".format(output.shape, output)) print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0])) print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0])) ``` 輸出如下: ``` output shape: torch.Size([1, 2, 1, 1]) output value: tensor([[[[ 9.]], [[18.]]]],