1. 程式人生 > >基於Pytorch的卷積神經網路剪枝

基於Pytorch的卷積神經網路剪枝

       本篇部落格對網路剪枝的實現方法主要在https://jacobgil.github.io/deeplearning/pruning-deep-learning的基礎上進行了相應修改而完成,所參考的論文為https://arxiv.org/abs/1611.06440。本篇部落格所使用的程式碼見https://github.com/PolarisShi/purning

       網路剪枝個人覺得是一種實用性非常強的網路壓縮方法,並且可以和其它模型壓縮方法如網路蒸餾、引數位壓縮等進行組合,在保留網路識別精度的同時極大幅度的減少網路在使用時的計算量。但是非常令人困惑的是,這種簡單粗暴實用的方法,雖然在16年就已經提出了,在網上能夠找到的資料反而不多。根據jacobgil的分析,可能的原因有:1、目前對剪枝的評價方法(決定哪一些引數應該被刪除)還不夠完善。2、以目前的框架很難實現網路的剪枝。3、各路大神都把這類網路壓縮方法作為自己的大招祕而不宣。個人覺得,第2點才是主要原因。。。jacobgil大神采用python2+pytorch實現了對VGG16網路的壓縮,不過正因為演算法實現較為複雜,所以對於不同的網路結構,還是要對演算法做相應調整,不過只要理解了演算法修改起來還是很容易的。

       剪枝演算法的原理非常簡單,如論文中下圖所示:

                                                                             

       對於一個訓練完畢的網路,首先評價各神經元的重要性,把重要性最低神經元移除,之後微調網路,迴圈上述三個步驟直到網路達到預定目標。所以問題主要在於兩個部分,即如何評價各神經元重要性,以及怎麼實現移除神經元。

       考慮一組訓練集D(包含輸入X和輸出Y),網路的引數W(包含weight和bias)被優化以使得代價函式C(D|W)最小。那麼對神經元h重要性的判斷就是將神經元h置0,此時代價函式變為:C(D|W, h=0)。我們需要讓剪枝前後網路模型的代價函式儘可能的相似,因此我們實質上是依次找出使函式abs( C(D|W) - C(D|W, h=0) )最小的h,之後依次刪除。

       下面主要結合jacobgil大神的實現方法說一說怎麼實現對我們自己的模型的剪枝。jacobgil將在實現的過程中,把模型分為了兩個部分,featrues和classifier,這可能和pytorch自帶的VGG網路模型格式有關。features和classifier都是sequence格式,主要就是對featrues中各卷積層的filter進行修剪,同時修改下一層連線的卷積層或者全連線層的相應輸入。

       這裡可以基於我的github中的程式碼來看怎麼實現,程式碼均基於jacobgil的程式碼進行了一定的修改。總共有四個檔案,main.py,prune.py,dataset.py和observe.py。dataset用於匯入資料,prune主要是實現剪枝操作,main就是主函式,實現了網路模型搭建、訓練、測試、神經元重要性評價等等,最後observe用來對網路訓練結果在測試集上的表現進行觀察。

      原始碼中,main.py中的class CNN用於建立網路模型。由於prune.py檔案的限制,網路最好寫成如下形式。features和classifier都是nn.Sequential形式。如果採用其他結構的網路,比如在featrues中嵌套了多個子類nn.Sequential或者在classifier中不適用nn.Linear而是僅採用卷積結構和全域性池化層進行分類,則需要對prune.py中的相應部分進行修改。

out = self.features(x)
out = out.view(out.size(0), -1)
out = self.classifier(out)

      class FilterPrunner主要用於對網路進行剪枝部分的操作。compute_rank函式就是計算網路中各filter的重要性,這之後將重要性低的filter進行記錄,最後輸出filters_to_prune這個引數作為剪枝的依據。

      class PrunningFineTuner_CNN就是整個程式在執行時的主要部分,包含了網路模型的引數配置、訓練、測試以及剪枝的呼叫。在def prune(self)中,我們需要設定每一次迴圈需要剪枝的數目、最終網路經過剪枝後的filter保留率、每次剪枝之後的fine tune過程的引數等等。我們可以看到呼叫剪枝過程的核心語句:

         model = self.model.cpu()
            for layer_index, filter_index in prune_targets:
                
                model = prune_conv_layer(model, layer_index, filter_index)
                
            self.model = model.cuda()

      這裡是將目前的模型,需要剪枝的網路層數以及filter的編號這三個引數輸入prune.py中的prune_conv_layer函式,輸出經過剪枝之後的網路,迴圈直到目標filter全部被修剪完畢,最後將修剪完成的模型替換原有的模型。

      在prune.py中,我們可以看到,這種修剪實際上是通過重新定義一個卷積層,之後替換原有卷積層來實現的,這也是由於當前框架的一些限制所導致的。注意如果網路中存在非卷積的其他和卷積層輸入輸出相關聯的層如nn.BatchNorm2d等,也需要相應跟隨卷積層進行調整。而如果在定義網路的時候嵌套了多個sequence結構,那還得修改main中的部分程式碼以使得能夠定位到子sequence中的卷積層才行。

      我個人還是建議大家基於jacobgil的原始碼,針對自己想要剪枝的神經網路,進行相應的修改,來加深對這種方法的理解。我進行實驗時,建立的CNN結構如下:

CNN(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace)
    (20): AvgPool2d(kernel_size=8, stride=8, padding=0)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=10, bias=True)
  )
)

      每一批剪枝的數目為32個filter,因為是實驗所以這樣設定一個恆定引數無妨,但是在實際使用的時候,最好還是將這個引數設定為一個與當前網路總filter成比例的一個數,來使得在剪枝後期網路模型較小的時候,網路一次性失去的filter不至於過多,以避免對正確率造成較大影響。

      我們設定的剪枝目標是原始網路filter數目的1/10,就是剪去90%的filter。總共經過25次修剪之後fine tune操作,歷代修剪的引數數量如下:

{0: 8, 3: 3, 10: 3, 14: 10, 17: 6, 7: 2}
{14: 8, 17: 14, 7: 4, 0: 1, 10: 5}
{17: 15, 14: 14, 7: 2, 0: 1}
{17: 10, 7: 4, 14: 10, 10: 6, 3: 2}
{10: 4, 14: 13, 7: 7, 17: 6, 3: 1, 0: 1}
{17: 15, 14: 6, 7: 3, 3: 1, 10: 4, 0: 3}
{17: 11, 7: 5, 0: 2, 14: 10, 3: 1, 10: 3}
{14: 12, 17: 12, 3: 2, 7: 3, 10: 3}
{17: 8, 7: 5, 14: 10, 3: 1, 10: 6, 0: 2}
{17: 17, 10: 6, 14: 4, 3: 1, 7: 4}
{17: 12, 7: 7, 14: 9, 10: 2, 0: 2}
{17: 4, 14: 12, 10: 8, 7: 2, 3: 4, 0: 2}
{0: 4, 14: 10, 17: 9, 10: 5, 7: 4}
{17: 11, 14: 14, 10: 2, 3: 2, 7: 3}
{17: 10, 10: 7, 14: 13, 3: 2}
{7: 8, 14: 13, 17: 6, 10: 4, 3: 1}
{14: 8, 17: 9, 10: 5, 7: 7, 3: 2, 0: 1}
{14: 10, 17: 11, 3: 2, 0: 2, 7: 3, 10: 4}
{17: 13, 3: 2, 14: 9, 7: 4, 10: 2, 0: 2}
{7: 10, 17: 8, 14: 6, 3: 2, 10: 5, 0: 1}
{17: 4, 14: 10, 3: 4, 10: 6, 7: 7, 0: 1}
{3: 2, 14: 9, 10: 4, 7: 6, 0: 6, 17: 5}
{7: 3, 14: 11, 10: 8, 17: 6, 3: 4}
{14: 8, 0: 2, 7: 4, 3: 5, 17: 7, 10: 6}
{14: 2, 7: 5, 0: 8, 10: 6, 3: 8, 17: 3}

      而各代的正確率和與原始網路大小的關係如下圖所示,圖中橫座標表示剪枝數目佔原始網路的百分比,縱座標:

                                                  

      而從圖中可以看出,隨著網路剪枝數目的增加,網路的準確率逐步下降,並且下降的越來越快。。。我覺得可能有兩個方面的原因:首先是網路本身的容量不足,其次則是固定剪枝數目而導致對最後幾代網路造成了不可逆的影響。