Self-Attention GAN 中的 self attention 機制
首先分別貼出 Self-Attention GAN 的文章和程式碼連結。
文章 arxiv.org pytorch 版本程式碼 github.com
Self Attention GAN 用到了很多新的技術。
最大的亮點當然是self-attention 機制 ,該機制是Non-local Neural Networks 這篇文章提出的。其作用是能夠更好地學習到全域性特徵之間的依賴關係。因為傳統的 GAN 模型很容易學習到紋理特徵:如皮毛,天空,草地等,不容易學習到特定的結構和幾何特徵,例如狗有四條腿,既不能多也不能少。
除此之外,文章還用到了Spectral Normalization for GANs 提出的譜歸一化 。譜歸一化的解釋見本人這篇知乎文章:Spectral Normalization 譜歸一化 。但是,該文程式碼中的譜歸一化和原始的譜歸一化運用方式略有差別:
- 原始的譜歸一化基於 W-GAN 的理論,只用在 Discriminator 中,用以約束 Discriminator 函式為 1-Lipschitz 連續。而在 Self-Attention GAN 中,Spectral Normalization 同時出現在了 Discriminator 和 Generator 中,用於使梯度更穩定。除了生成器和判別器的最後一層外,每個 卷積/反捲積 單元都會上一個 SpectralNorm。
- 當把譜歸一化用在 Generator 上時,同時還保留了 BatchNorm。Discriminator 上則沒有 BatchNorm,只有 SpectralNorm。
- 譜歸一化用在 Discriminator 上時最後一層不加 Spectral Norm。
最後,self-attention GAN 還用到了cGANs With Projection Discriminator 提出的conditional normalization 和projection in the discriminator 。這兩個技術我還沒有來得及看,而且 pytorch 版本的 self-attention gan 程式碼中也沒有實現,就先不管它們了。本文主要說的是 self-attention 這部分內容。

Self-Attention
在卷積神經網路中,每個卷積核的尺寸都是很有限的(基本上不會大於5),因此每次卷積操作只能覆蓋畫素點周圍很小一塊鄰域。對於距離較遠的特徵,例如狗有四條腿這類特徵,就不容易捕獲到了(也不是完全捕獲不到,因為多層的卷積、池化操作會把 feature map 的高和寬變得越來越小,越靠後的層,其卷積核覆蓋的區域映射回原圖對應的面積越大。但總而言之,畢竟還得需要經過多層對映,不夠直接)。Self-Attention 通過直接計算影象中任意兩個畫素點之間的關係,一步到位地獲取影象的全域性幾何特徵。
論文中的公式不夠直觀,我們直接看文章開頭的 pytorch 的程式碼,核心部分為 sagan_models.py:
class Self_Attn(nn.Module): """ Self attention Layer""" def __init__(self,in_dim,activation): super(Self_Attn,self).__init__() self.chanel_in = in_dim self.activation = activation self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1) self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1) self.gamma = nn.Parameter(torch.zeros(1)) self.softmax= nn.Softmax(dim=-1) # def forward(self,x): """ inputs : x : input feature maps( B X C X W X H) returns : out : self attention value + input feature attention: B X N X N (N is Width*Height) """ m_batchsize,C,width ,height = x.size() proj_query= self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N) proj_key =self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H) energy =torch.bmm(proj_query,proj_key) # transpose check attention = self.softmax(energy) # BX (N) X (N) proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N out = torch.bmm(proj_value,attention.permute(0,2,1) ) out = out.view(m_batchsize,C,width,height) out = self.gamma*out + x return out,attention
建構函式中定義了三個
的卷積核,分別被命名為
query_conv
,key_conv
和value_conv
。為啥命名為這三個名字呢?這和作者給它們賦予的含義有關。query 意為查詢,我們希望輸入一個畫素點,查詢(計算)到 feature map 上所有畫素點對這一點的影響。而 key 代表字典中的鍵,相當於所查詢的資料庫。query 和 key 都是輸入的 feature map,可以看成把 feature map 複製了兩份,一份作為 query 一份作為 key。
需要用一個什麼樣的函式,才能針對 query 的 feature map 中的某一個位置,計算出 key 的 feature map 中所有位置對它的影響呢?作者認為這個函式應該是可以通過“學習”得到的。那麼,自然而然就想到要對這兩個 feature map 分別做卷積核為
的卷積了,因為卷積核的權重是可以學習得到的。
至於value_conv
,可以看成對原 feature map 多加了一層卷積對映,這樣可以學習到的引數就更多了,否則qurey_conv
和key_conv
的引數太少,按程式碼中只有in_dims
in_dims//8
個。
接下來逐行研究 forward 函式:
proj_query= self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1)
這行程式碼先對輸入的 feature map 卷積了一次,相當於對 query feature map 做了一次投影,所以叫做proj_query
。由於是
的卷積,所以不改變 feature map 的長和寬。feature map 的每個通道為如 (1) 所示的矩陣,矩陣共有 N 個元素(畫素)。
然後重新改變了輸出的維度,變成:
(m_batchsize,-1,width*height)
batch size 保持不變,width 和 height 融合到一起,把如(1)所示二維的 feature map 每個 channel 拉成一個長度為 N 的向量。因此,如果m_batchsize
取 1,即單獨觀察一個樣本,該操作的結果是得到一個矩陣,矩陣的的行數為query_conv
卷積輸出的 channel 的數目
(
in_dim//8
),列數為 feature map 畫素數
。然後作者又通過
.permute(0, 2, 1)
轉置了矩陣,矩陣的行數變成了 feature map 的畫素數
,列數變成了通道數
。因此矩陣維度為
。該矩陣每行代表一個畫素位置上所有通道的值,每列代表某個通道中所有的畫素值。

proj_key =self.key_conv(x).view(m_batchsize,-1,width*height)
這行程式碼和上一行類似,只不過取消了轉置操作。得到的矩陣行數為通道數
,列數為畫素數
,即矩陣維度為
。該矩陣每行代表一個通道中所有的畫素值,每列代表一個畫素位置上所有通道的值。

energy =torch.bmm(proj_query,proj_key)
這行程式碼中,torch.bmm
的意思是 batch matrix multiplication。就是說把相同 batchsize 的兩組 matrix 一一對應地做矩陣乘法,最後得到同樣 batchsize 的新矩陣。若 batchsize=1,就是普通的矩陣乘法。已知proj_query
維度是
,
proj_key
的維度是
,因此
energy
的維度是
:

energy
是 attention 的核心,其中第 i 行 j 列的元素,是由proj_query
第 i 行,和proj_key
第 j 列通過向量點乘得到的。而proj_query
第 i 行表示的是 feature map 上第 i 個畫素位置上所有通道的值,也就是第 i 個畫素位置的所有資訊,而proj_key
第 j 列表示的是 feature map 上第 j 個畫素位置上的所有通道值,也就是第 j 個畫素位置的所有資訊。這倆相乘,可以看成是第 j 個畫素對第 i 個畫素的影響。即,energy 中第 i 行 j 列的元素值,表示第 j 個畫素點對第 i 個畫素點的影響。
attention = self.softmax(energy)
這裡 sofmax 是建構函式中定義的,為按“行”歸一化。這個操作之後的矩陣,各行元素之和為1。這也比較好理解,因為 energy 中第 i 行元素,代表 feature map 中所有位置的畫素對第 i 個畫素的影響,而這個影響被解釋為權重,故加起來應該是 1,故應對其按行歸一化。attention
的維度也是
。
proj_value = self.value_conv(x).view(m_batchsize,-1,width*height)
上面的程式碼中,先對原 feature map 作一次卷積對映,然後把得到的新 feature map 改變形狀,維度變為
,其中 C 為通道數(注意和上面計算
proj_query
proj_key
的 C 不同,上面的 C 為 feature map 通道數的 1/8,這裡的 C 與 feature map 通道數相同),N 為 feature map 的畫素數。

out = torch.bmm(proj_value,attention.permute(0,2,1) ) out = out.view(m_batchsize,C,width,height)
然後,再把proj_value
(
)矩陣同 attention 矩陣的轉置(
)相乘,得到
out
(
)。之所以轉置,是因為
attention
中每行的和為1,其意義是權重,需要轉置後變為每列的和為1,施加於proj_value
的行上,作為該行的加權平均。proj_value
第 i 行代表第 i 個通道所有的畫素值,attention
第 j 列,代表所有畫素施加到第 j 個畫素的影響。因此,out
中第 i 行包含了輸出的第 i 個通道中的所有畫素,第 j 列表示所有畫素中的第 j 個畫素,合起來也就是:out
中的第 i 行第 j 列的元素,表示被 attention 加權之後的 feature map 的第 i 個通道的第 j 個畫素的畫素值。再改變一下形狀,out
就恢復了 channel×width×height的結構。

out = self.gamma*out + x
最後一行程式碼,借鑑了殘差神經網路(residual neural networks)的操作,gamma
是一個引數,表示整體施加了 attention 之後的 feature map 的權重,需要通過反向傳播更新。而x
就是輸入的 feature map。在初始階段,gamma
為0,該 attention 模組直接返回輸入的 feature map,之後隨著學習,該 attention 模組逐漸學習到了將 attention 加權過的 feature map 加在原始的 feature map 上,從而強調了需要施加註意力的部分 feature map。
總結
可以把 self attention 看成是 feature map 和它自身的轉置相乘,讓任意兩個位置的畫素直接發生關係,這樣就可以學習到任意兩個畫素之間的依賴關係,從而得到全域性特徵了。看論文時會被它複雜的符號迷惑,但是一看程式碼就發現其實是很naive的操作。