1. 程式人生 > >深度可分離卷積結構(depthwise separable convolution)計算複雜度分析

深度可分離卷積結構(depthwise separable convolution)計算複雜度分析

這個例子說明了什麼叫做空間可分離卷積,這種方法並不應用在深度學習中,只是用來幫你理解這種結構。

在神經網路中,我們通常會使用深度可分離卷積結構(depthwise separable convolution)。

這種方法在保持通道分離的前提下,接上一個深度卷積結構,即可實現空間卷積。接下來通過一個例子讓大家更好地理解。

假設有一個3×3大小的卷積層,其輸入通道為16、輸出通道為32。具體為,32個3×3大小的卷積核會遍歷16個通道中的每個資料,從而產生16×32=512個特徵圖譜。進而通過疊加每個輸入通道對應的特徵圖譜後融合得到1個特徵圖譜。最後可得到所需的32個輸出通道。

針對這個例子應用深度可分離卷積,用16個3×3大小的卷積核分別遍歷16通道的資料,得到了16個特徵圖譜。在融合操作之前,接著用32個1×1大小的卷積核遍歷這16個特徵圖譜,進行相加融合。這個過程使用了16×3×3+16×32×1×1=656個引數,遠少於上面的16×32×3×3=4608個引數。

這個例子就是深度可分離卷積的具體操作,其中上面的深度乘數(depth multiplier)設為1,這也是目前這類網路層的通用引數。

這麼做是為了對空間資訊和深度資訊進行去耦。從Xception模型的效果可以看出,這種方法是比較有效的。由於能夠有效利用引數,因此深度可分離卷積也可以用於移動裝置中。

src convolution

  input                              output

M*N*Cin                      M*N*Cout

             16*3*3*32

depthwise separable convolution

  input                       output1                          output2

M*N*Cin                 M*N*Cin                       M*N*Cout

               16*3*3                       16*32*1*1

另外一個地方看到的解釋:

MobileNet-v1:

MobileNet主要用於移動端計算模型,是將傳統的卷積操作改為兩層的卷積操作,在保證準確率的條件下,計算時間減少為原來的1/9,計算引數減少為原來的1/7.

MobileNet模型的核心就是將原本標準的卷積操作因式分解成一個depthwise convolution和一個1*1的pointwise convolution操作。簡單講就是將原來一個卷積層分成兩個卷積層,其中前面一個卷積層的每個filter都只跟input的每個channel進行卷積,然後後面一個卷積層則負責combining,即將上一層卷積的結果進行合併。 

depthwise convolution:

比如輸入的圖片是Dk*Dk*M(Dk是圖片大小,M是輸入的渠道數),那麼有M個Dw*Dw的卷積核,分別去跟M個渠道進行卷積,輸出Df*Df*M結果

pointwise convolution:

對Df*Df*M進行卷積合併,有1*1*N的卷積,進行合併常規的卷積,輸出Df*Df*N的結果

上面經過這兩個卷積操作,從一個Dk*Dk*M=>Df*Df*N,相當於用Dw*Dw*N的卷積核進行常規卷積的結果,但計算量從原來的DF*DF*DK*DK*M*N減少為DF*DF*DK*DK*M+DF*DF*M*N.

第一層為常規卷積,後面接著都為depthwise convolution+pointwise convolution,最後兩層為Pool層和全連線層,總共28層.

下面的程式碼是mobilenet的一個引數列表,計算的普通卷積與深度分離卷積的計算複雜程度比較

# Tensorflow mandates these.
from collections import namedtuple
import functools

import tensorflow as tf

slim = tf.contrib.slim

# Conv and DepthSepConv namedtuple define layers of the MobileNet architecture
# Conv defines 3x3 convolution layers
# DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution.
# stride is the stride of the convolution
# depth is the number of channels or filters in a layer
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])

# _CONV_DEFS specifies the MobileNet body
_CONV_DEFS = [
    Conv(kernel=[3, 3], stride=2, depth=32),
    DepthSepConv(kernel=[3, 3], stride=1, depth=64),
    DepthSepConv(kernel=[3, 3], stride=2, depth=128),
    DepthSepConv(kernel=[3, 3], stride=1, depth=128),
    DepthSepConv(kernel=[3, 3], stride=2, depth=256),
    DepthSepConv(kernel=[3, 3], stride=1, depth=256),
    DepthSepConv(kernel=[3, 3], stride=2, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=2, depth=1024),
    DepthSepConv(kernel=[3, 3], stride=1, depth=1024)
]

input_size = 160
inputdepth = 3
conv_defs = _CONV_DEFS
sumcost = 0
for i, conv_def in enumerate(conv_defs):
    stride = conv_def.stride
    kernel = conv_def.kernel
    outdepth = conv_def.depth
    output_size = round((input_size - int(kernel[0] / 2) * 2) / stride)
    if isinstance(conv_def, Conv):
        sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
    if isinstance(conv_def, DepthSepConv):
        sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
    inputdepth = outdepth
    input_size = output_size
print("src conv:    ", sumcost)

input_size = 160
inputdepth = 3
conv_defs = _CONV_DEFS
sumcost1 = 0
for i, conv_def in enumerate(conv_defs):
    stride = conv_def.stride
    kernel = conv_def.kernel
    outdepth = conv_def.depth
    output_size = round((input_size - int(kernel[0] / 2) * 2) / stride)
    if isinstance(conv_def, Conv):
        sumcost1 += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
    if isinstance(conv_def, DepthSepConv):
        #sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
        sumcost1 += output_size * output_size *(inputdepth * kernel[0] * kernel[0]  + inputdepth * outdepth * 1 * 1)
    inputdepth = outdepth
    input_size = output_size
print("DepthSepConv:", sumcost1)
print("compare:", sumcost1 / sumcost)

src conv:            1045417824 DepthSepConv:   126373376 compare: 0.12088312739538674