Attention Mechanisms: Attention Augmented Convolutional Networks

Original link: https://www.yuque.com/lart/papers/aaconv

Core Idea

We propose to augment convolutional operators with this self-attention mechanism by concatenating convolutional feature maps with a set of feature maps produced via self-attention.

Main Contributions

First, note two properties of the convolution operation itself: locality (a limited receptive field) and translation equivariance (obtained through weight sharing).

Although these properties have proven to be crucial inductive biases when designing models that operate on images, the local nature of the convolutional kernel prevents it from capturing global context, which is necessary for image recognition. This is an important weakness of convolution (the convolution operator is limited by its locality and lack of understanding of global contexts).

For capturing long range interactions, self-attention has emerged as a recent advance. The key idea behind self-attention is to produce a weighted average of values computed from hidden units. Unlike convolution or pooling, these weights are produced dynamically from the input features, via a similarity function between hidden units. As a result, the interaction between input signals depends on the signals themselves, rather than being predetermined by their relative positions as in convolution.

This paper therefore tries to bring self-attention into the convolution operation in order to capture long range interactions. For discriminative visual tasks, it considers using self-attention in place of ordinary convolution, and introduces a novel two-dimensional relative self-attention mechanism that maintains translation equivariance while being infused with relative position information, making it well suited to images.

Self-attention proves competitive as a standalone computational unit replacing convolution. However, the controlled experiments show that the best results come from combining self-attention with convolution. Convolution is therefore not discarded entirely; instead, the paper proposes to use the self-attention mechanism to augment convolutions, concatenating convolutional feature maps, which emphasize locality, with feature maps produced by self-attention, which are capable of modeling longer range dependencies, to obtain the final output.

Attention augmented convolution yields consistent improvements across multiple experiments. Moreover, the fully self-attentional model (without the convolutional part), which can be seen as a special case of the attention augmented model, performs only slightly worse than its fully convolutional counterpart on ImageNet, suggesting that self-attention is a powerful standalone computational primitive for image classification.

On the term "primitive", I found an explanation; roughly, it refers to the most basic building block of a system.

https://stackoverflow.com/a/8022435

For me, it means something that cannot be decomposed (people use also the atomic word sometimes in that sense, but atomic is often also used for explanation on concurrency or parallelism with a different meaning).

For instance, on Unix (or Linux) the system calls, as seen by the application are primitive or atomic, they either happen or not (sometimes, they got interrupted and give an EINTR or ERESTART error).

And inside an interpreter, or even in the formal specification, of a language, the primitive are those operations which you cannot define, and which the interpreter deals with specially. Very often, cons is a primitive operation for Lisp dialects.

The paper also mentions several other works that apply attention in visual tasks:

Compared with existing approaches, the architecture proposed here does not rely on pretraining of fully convolutional counterparts; instead the whole network uses the self-attention mechanism. In addition, the use of multi-head attention lets the model attend jointly to spatial and feature subspaces. (Multi-head attention splits the features into groups along the channel dimension and transforms each group separately, which yields more diverse feature representations.)

Furthermore, to enhance the expressive power of self-attention over images, the relative self-attention of [Self-attention with relative position representations, Music Transformer] is extended to two dimensions, which makes it possible to model translation equivariance in a principled way.

This design directly produces additional feature maps, rather than recalibrating convolutional features through addition (or multiplication) [Non-local neural networks, Self-attention generative adversarial networks] or gating [Squeeze-and-excitation networks, Gather-Excite: Exploiting feature context in convolutional neural networks, BAM: Bottleneck attention module, CBAM: Convolutional block attention module]. This property makes it possible to flexibly adjust the fraction of attention channels, covering a spectrum of architectures ranging from fully convolutional to fully attentional models.

Main Architecture

  • H, W, Fin: the height, width, and number of channels of the input feature map
  • Nh, dv, dk: the number of heads, the depth of the values (i.e. the number of channels of the attention feature maps), and the depth of the queries and keys (these are the parameters of MHA, multi-head attention). dv and dk must both be divisible by Nh; dhv and dhk denote the per-head depth of the values and of the queries/keys respectively.

Computing Multi-Head Attention on Image Data

Single-head form (Equation 1): \(O_h = \mathrm{Softmax}\left(\frac{(XW_q)(XW_k)^{\top}}{\sqrt{d^h_k}}\right)(XW_v)\)

The multi-head output is formed by concatenating the single heads (Equation 2): \(MHA(X) = \mathrm{Concat}\left[O_1, \dots, O_{N_h}\right]W^O\)

  1. in_tensor\((H,W,F_{in})\) =(flatten)=> X\((HW,F_{in})\) (We omit the batch dimension for simplicity.)
  2. Compute multi-head attention as in the Transformer:
    1. The self-attention output of head h is given by Equation 1, where \(W_q\)/\(W_k\)/\(W_v\) have shapes \((F_{in}, d^h_q)/(F_{in}, d^h_k)/(F_{in}, d^h_v)\) and map the input X to the queries \(Q=XW_q\), keys \(K=XW_k\) and values \(V=XW_v\), whose shapes are \((HW, d^h_q)/(HW, d^h_k)/(HW, d^h_v)\) respectively.
    2. The outputs of all heads are concatenated and processed as in Equation 2, where \(W^O \in \mathbb{R}^{d_v \times d_v}\) (note that the concatenation of the \(N_h\) head outputs \(O\) has depth \(d_v\), i.e. \(d_v=N_h \times d^h_v\)). The MHA result is then reshaped to \((H, W, d_v)\) to match the original spatial dimensions. (A shape-level sketch follows this list.)
    3. multi-head attention
      1. computational complexity: \(O((HW)^2d_k)\) (only the dominant term, the \((XW_q)(XW_k)^T\) product, needs to be considered)
      2. memory complexity: \(O((HW)^2N_h)\) (this accounts for the results of all Nh heads)
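
To make the shape bookkeeping above concrete, here is a minimal PyTorch sketch of the multi-head computation over a flattened feature map; all sizes below are toy values chosen for illustration, not settings from the paper.

import torch

B, H, W, F_in = 2, 8, 8, 32        # toy sizes, chosen only for illustration
Nh, dk, dv = 4, 24, 24             # dk and dv must both be divisible by Nh
dkh, dvh = dk // Nh, dv // Nh

X = torch.randn(B, H * W, F_in)    # step 1: flatten the spatial dimensions
Wq, Wk, Wv = torch.randn(F_in, dk), torch.randn(F_in, dk), torch.randn(F_in, dv)
Wo = torch.randn(dv, dv)

# Project to queries/keys/values and split into Nh heads: B, Nh, HW, d^h
Q = (X @ Wq).reshape(B, H * W, Nh, dkh).transpose(1, 2)
K = (X @ Wk).reshape(B, H * W, Nh, dkh).transpose(1, 2)
V = (X @ Wv).reshape(B, H * W, Nh, dvh).transpose(1, 2)

logits = Q @ K.transpose(-2, -1) / dkh ** 0.5        # B, Nh, HW, HW  (Equation 1, per head)
O = logits.softmax(-1) @ V                           # B, Nh, HW, dvh
MHA = O.transpose(1, 2).reshape(B, H * W, dv) @ Wo   # concatenate heads, then apply W^O (Equation 2)
MHA = MHA.reshape(B, H, W, dv)                       # reshape back to (H, W, dv)
print(MHA.shape)                                     # torch.Size([2, 8, 8, 24])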

Two-dimensional Positional Embeddings

"Two-dimensional" here is relative to the original one-dimensional setting for language; the input here is two-dimensional image data.

Since no explicit positional information is used, self-attention commutes with permutations of the pixel positions: \(MHA(\pi(X))=\pi(MHA(X))\), where \(\pi\) is any permutation of pixel locations. In other words, self-attention is permutation equivariant. This property makes it less effective for modeling highly structured data such as images. (The small check below verifies this numerically.)
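
A quick way to see this is to run a tiny single-head, position-unaware attention on randomly permuted inputs; the snippet below (toy sizes of my own choosing) checks that \(MHA(\pi(X))=\pi(MHA(X))\) holds numerically.

import torch

B, N, F = 1, 16, 8                       # toy sizes; N plays the role of HW
X = torch.randn(B, N, F)
Wq, Wk, Wv = torch.randn(F, F), torch.randn(F, F), torch.randn(F, F)

def attention(x):
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    return torch.softmax(q @ k.transpose(-2, -1) / F ** 0.5, dim=-1) @ v

perm = torch.randperm(N)                 # an arbitrary permutation pi of pixel positions
lhs = attention(X[:, perm])              # MHA(pi(X))
rhs = attention(X)[:, perm]              # pi(MHA(X))
print(torch.allclose(lhs, rhs, atol=1e-5))  # True: without positions, attention ignores pixel order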

Several positional encodings that augment activation maps with explicit spatial information have been proposed to deal with this issue:

  1. Image Transformer extends the sinusoidal waves first introduced in the original Transformer to 2 dimensional inputs.
  2. CoordConv concatenates positional channels to an activation map.

The experiments in the paper find that these encodings do not work well for image classification and object detection. The authors attribute this to the fact that, while such schemes do break permutation equivariance, they do not guarantee the translation equivariance that image tasks require (permutation equivariance vs. translation equivariance). To this end, the existing relative position encoding of [Self-attention with relative position representations] is extended to two dimensions, and a memory-efficient implementation based on the Music Transformer is proposed.

Relative Positional Embeddings

Introduced in [Self attention with relative position representations] for the purpose of language modeling, relative self-attention augments self-attention with relative position encodings and enables translation equivariance while preventing permutation equivariance.

Two-dimensional relative self-attention is implemented here by adding relative width and relative height information independently.

The attention logit for how much pixel \(i=(i_x, i_y)\) attends to pixel \(j=(j_x, j_y)\) is computed as \(l_{i,j} = \frac{q_i^{\top}}{\sqrt{d^h_k}}\left(k_j + r^W_{j_x-i_x} + r^H_{j_y-i_y}\right)\), where:

  • \(q_i\) is the query vector at position \(i\), i.e. one row of Q of length \(d^h_k\).
  • \(k_j\) is the key vector at position \(j\), i.e. one row of K of length \(d^h_k\).
  • \(r^W_{j_x-i_x}\) and \(r^H_{j_y-i_y}\) are learned embeddings for the relative width \(j_x-i_x\) and relative height \(j_y-i_y\), each a vector of length \(d^h_k\).
  • The corresponding relative position parameter matrices \(r^W\) and \(r^H\) have shapes \((2W-1, d^h_k)\) and \((2H-1, d^h_k)\) respectively.

The output of a single head h then becomes \(O_h = \mathrm{Softmax}\left(\frac{QK^{\top} + S^{rel}_H + S^{rel}_W}{\sqrt{d^h_k}}\right)V\), where:

Both \(S^{rel}_H\) and \(S^{rel}_W\) are \(HW \times HW\) matrices containing the relative position logits along the height and width dimensions respectively.

Because only relative width/height information is involved, we have \(S^{rel}_W[i, j]=S^{rel}_W[i, j+W]\) and \(S^{rel}_H[i, j]=S^{rel}_H[i, j+H]\), so the logits do not have to be computed for every pair (i, j). My own way of understanding this (checked numerically in the snippet after this list): view the 2D grid with rows running along the W (x) direction and columns along the H (y) direction; then for any point \(j\) and a fixed point \(i\):

  • in SW, \((j_x-i_x)\%W=[(j+nW)_x-i_x]\%W\), i.e. moving forward by \(nW\) positions in row-major order lands in the same column;
  • in SH, \((j_y-i_y)\%H=[(j+nH)_y-i_y]\%H\), i.e. moving forward by \(nH\) positions in column-major order lands in the same row.
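
The snippet below builds \(S^{rel}_W\) naively for a toy grid (sizes are arbitrary) and checks the periodicity \(S^{rel}_W[i, j]=S^{rel}_W[i, j+W]\) that the reasoning above relies on.

import torch

H, W, dkh = 4, 5, 8                      # toy grid and per-head key depth
q = torch.randn(H * W, dkh)              # one query per flattened pixel
r_w = torch.randn(2 * W - 1, dkh)        # relative-width embeddings, indexed by j_x - i_x

# Naive construction: S_w[i, j] = q_i . r^W_{j_x - i_x}
S_w = torch.zeros(H * W, H * W)
for i in range(H * W):
    for j in range(H * W):
        rel = (j % W) - (i % W)          # relative width in [-(W-1), W-1]
        S_w[i, j] = q[i] @ r_w[rel + W - 1]

# Shifting j by W moves one row down in row-major order: same column, hence the same logit.
i, j = 7, 3
print(torch.allclose(S_w[i, j], S_w[i, j + W]))  # True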

Note that the form of relative attention used here differs from the original design in Self-attention with relative position representations, whose memory footprint is \(O((HW)^2d^h_k)\) (relative embeddings \(r_{ij} \in \mathbb{R}^{HW \times HW \times d^h_k}\)). Instead, it is a 2D extension of the memory efficient relative masked attention algorithm proposed in the Music Transformer, extended to unmasked relative self-attention over 2-dimensional inputs, which reduces the memory cost to \(O(HWd^h_k)\): the relative position embedding \(r_{ij}\) is split into two parts, \(r^H \in \mathbb{R}^{(2H-1) \times d^h_k}\) and \(r^W \in \mathbb{R}^{(2W-1) \times d^h_k}\), which are shared across heads but not across layers. Each layer therefore only needs an extra \((2(H + W) − 2)d^h_k\) parameters to model the relative distances along height and width.
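
To get a feel for the savings, the following back-of-the-envelope computation (with arbitrary feature-map sizes of my own choosing, not the paper's) compares the number of stored embedding entries in the naive per-pair scheme against the shared \(r^H\)/\(r^W\) scheme.

# Arbitrary illustrative sizes, not taken from the paper.
H, W, dkh = 32, 32, 16

naive = (H * W) ** 2 * dkh                   # r_ij in R^{HW x HW x d^h_k}
shared = ((2 * H - 1) + (2 * W - 1)) * dkh   # r^H and r^W, shared across heads
print(naive, shared)                         # 16777216 vs 2016 stored embedding entries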

Attention Augmented Convolution

The main advantages of the attention augmented convolution proposed in the paper:

  1. use an attention mechanism that can attend jointly to spatial and feature subspaces (each head corresponding to a feature subspace)
  2. introduce additional feature maps rather than refining them

The main computation of AAConv is \(\mathrm{AAConv}(X) = \mathrm{Concat}\left[\mathrm{Conv}(X),\ \mathrm{MHA}(X)\right]\).

Similarly to the convolution, the proposed attention augmented convolution

  1. is equivariant to translation
  2. can readily operate on inputs of different spatial dimensions

Next, the parameter count of AAConv is analyzed against an ordinary convolution of shape \((F_{out}, F_{in}, k, k)\):

  • Let \(v=\frac{d_v}{F_{out}}\) be the ratio of the number of output channels of the MHA part to the total number of AAConv output channels;
  • Let \(\kappa = \frac{d_k}{F_{out}}\) be the ratio of the key depth in MHA to the total number of AAConv output channels.
  • Q, K and V are obtained with \(1 \times 1\) convolutions, contributing \((d_v+d_k+d_q)F_{in} = (2d_k+d_v)F_{in}=(v+2\kappa)F_{out}F_{in}\) parameters.
  • An additional \(1\times1\) convolution is used to mix the contributions of the different heads, adding \(d_vd_v=(vF_{out})^2\) parameters;
  • Besides the attention part, there is the standard convolution, i.e. the Conv term in the formula above, with \(k^2(F_{out} - d_v)F_{in} = k^2(1 - v)F_{out}F_{in}\) parameters;
  • Ignoring the relative position embeddings and convolution biases, the overall parameter count is therefore approximately \(F_{in}F_{out}(2\kappa+v+v^2\frac{F_{out}}{F_{in}}+k^2-k^2v)=F_{in}F_{out}(2\kappa+v(1-k^2)+k^2+v^2\frac{F_{out}}{F_{in}})\)
  • The change relative to a plain convolution is \(\Delta_{params}\sim F_{in}F_{out}(2\kappa+v(1-k^2)+v^2\frac{F_{out}}{F_{in}})\), so replacing a 3x3 convolution slightly decreases the parameter count, while replacing a 1x1 convolution slightly increases it (a numeric check follows this list).
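
As a quick numeric check of the formulas above (the values of \(F_{in}\), \(F_{out}\), \(v\) and \(\kappa\) below are my own illustrative choices, not the paper's settings):

# Illustrative values only; v and kappa are chosen arbitrarily for this check.
F_in = F_out = 256
k, v, kappa = 3, 0.25, 0.25

conv_params = k ** 2 * F_in * F_out                            # plain k x k convolution
aa_params = F_in * F_out * (2 * kappa + v * (1 - k ** 2)       # formula above, ignoring
                            + k ** 2 + v ** 2 * F_out / F_in)  # rel. embeddings and biases
print(conv_params, int(aa_params), int(aa_params - conv_params))  # 589824 495616 -94208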

Attention Augmented Convolutional Architectures

  • In all experiments, each AAConv is followed by batch normalization, which can learn to scale the contributions of the convolutional feature maps and the attention feature maps.
  • AAConv is used once per residual block.
  • Since the QK product has a large memory footprint, augmented convolutions are applied starting from the deepest layers and moving toward shallower ones, until the memory limit is reached.
  • To reduce the memory footprint of augmented networks, we typically resort to a smaller batch size and sometimes additionally downsample the inputs to self-attention in the layers with the largest spatial dimensions where it is applied (this presumably means downsampling before and upsampling after the attention computation). Downsampling is performed by applying 3x3 average pooling with stride 2, while the following upsampling (required for the concatenation) is obtained via bilinear interpolation. (A small wrapper sketch follows this list.)
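
A minimal sketch of that downsample/upsample trick, assuming a self-attention module such as the SelfAttention2D defined in the code example below (the attention module must be constructed for the pooled spatial size); the wrapper name and structure are my own, not from the paper.

import torch
from torch import nn
import torch.nn.functional as F

class DownsampledAttention(nn.Module):
    """Run self-attention on a 3x3/stride-2 average-pooled map, then upsample bilinearly
    so the result can be concatenated with the convolutional feature maps."""

    def __init__(self, attention: nn.Module):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.attention = attention  # expects inputs of the pooled spatial size

    def forward(self, x):
        h, w = x.shape[-2:]
        out = self.attention(self.pool(x))
        return F.interpolate(out, size=(h, w), mode="bilinear", align_corners=False)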

Experimental Results

Positional Encodings

  • the position-unaware version of self-attention (referred to as None),
  • a two-dimensional implementation of the sinusoidal positional waves (referred to as 2d Sine) as used in [32],
  • CoordConv [29] for which we concatenate (x, y, r) coordinate channels to the inputs of the attention function,
  • our proposed two-dimensional relative position encodings (referred to as Relative).

Future Directions

  • Several open questions from this work remain. In future work, we will focus on the fully attentional regime and explore how different attention mechanisms trade off computational efficiency versus representational power. For instance, identifying a local attention mechanism may result in an efficient and scalable computational mechanism that could prevent the need for downsampling with average pooling [Stand-alone self-attention in vision models].
  • Additionally, it is plausible that architectural design choices that are well suited when exclusively relying on convolutions are suboptimal when using self-attention mechanisms. As such, it would be interesting to see if using Attention Augmentation as a primitive in automated architecture search procedures proves useful to find even better models than those previously found in image classification [55], object detection [12], image segmentation [6] and other domains [5, 1, 35, 8].
  • Finally, one can ask to which degree fully attentional models can replace convolutional networks for visual tasks.

Code Example

Based on the TensorFlow implementation from the authors' paper, I rewrote it in PyTorch.

import torch
from einops import rearrange
from torch import nn


def rel_to_abs(x):
    """
    Converts a tensor from relative to absolute indexing.
    Details can be found at: https://www.yuque.com/lart/ugkv9f/oazsec

    :param x: B Nh L 2L-1
    :return: B Nh L L
    """
    B, Nh, L, _ = x.shape
    # Pad to shift from relative to absolute indexing.
    col_pad = torch.zeros(B, Nh, L, 1, device=x.device, dtype=x.dtype)
    x = torch.cat([x, col_pad], dim=3)
    flat_x = x.reshape(B, Nh, L * 2 * L)
    flat_pad = torch.zeros(B, Nh, L - 1, device=x.device, dtype=x.dtype)
    flat_x = torch.cat([flat_x, flat_pad], dim=2)
    # Reshape and slice out the padded elements.
    final_x = flat_x.reshape(B, Nh, L + 1, 2 * L - 1)
    final_x = final_x[:, :, :L, L - 1:]
    return final_x


def relative_logits_1d(x, rel_k):
    """
    Compute relative logits along one dimension.

    :param x: B Nh Hd L
    :param rel_k: 2L-1 Hd
    """
    rel_logits = torch.einsum("bndl, rd -> bnlr", x, rel_k)
    rel_logits = rel_to_abs(rel_logits)  # B Nh L 2L-1 -> B Nh L L
    return rel_logits


class RelativePosEmbedding(nn.Module):
    """
    Compute relative_logits. For ease, we 1) transpose height and width, 2) repeat the above
    steps and 3) transpose to eventually put the logits in their right positions.
    """

    def __init__(self, h, w, dim):
        super(RelativePosEmbedding, self).__init__()
        self.h = h
        self.w = w
        self.rel_emb_w = nn.Parameter(torch.empty(2 * w - 1, dim))
        nn.init.normal_(self.rel_emb_w, std=dim ** -0.5)
        self.rel_emb_h = nn.Parameter(torch.empty(2 * h - 1, dim))
        nn.init.normal_(self.rel_emb_h, std=dim ** -0.5)

    def forward(self, x):
        """
        :param x: B Nh Hd HW
        :return: B Nh HW HW
        """
        Nh = x.shape[1]
        # Relative logits in width dimension first.
        rel_logits_w = relative_logits_1d(
            rearrange(x, "b nh hd (h w) -> b (nh h) hd w", h=self.h, w=self.w), self.rel_emb_w
        )
        rel_logits_w = rearrange(rel_logits_w, "b (nh h) w0 w1 -> b nh h () w0 w1", nh=Nh)
        # Relative logits in height dimension next.
        rel_logits_h = relative_logits_1d(
            rearrange(x, "b nh hd (h w) -> b (nh w) hd h", h=self.h, w=self.w), self.rel_emb_h
        )
        rel_logits_h = rearrange(rel_logits_h, "b (nh w) h0 h1 -> b nh h0 h1 w ()", nh=Nh)
        return rearrange(rel_logits_h + rel_logits_w, "b nh h0 h1 w0 w1 -> b nh (h0 w0) (h1 w1)")


class AbsolutePosEmbedding(nn.Module):
    """
    Given query q of shape [batch heads dim tokens], we multiply q by the embeddings of all
    flattened absolute positions. Learned embedding representations are shared across heads.
    """

    def __init__(self, h, w, dim):
        super().__init__()
        scale = dim ** -0.5
        self.abs_pos_emb = nn.Parameter(torch.randn(h * w, dim) * scale)

    def forward(self, x):
        """
        :param x: B Nh Hd HW
        :return: B Nh HW HW
        """
        return torch.einsum("bndx, yd -> bnxy", x, self.abs_pos_emb)


class SelfAttention2D(nn.Module):
    def __init__(self, in_dim, key_dim, value_dim, nh, hw, pos_mode="relative"):
        super(SelfAttention2D, self).__init__()
        self.dkh = key_dim // nh
        self.dvh = value_dim // nh
        self.nh = nh
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.kqv_proj = nn.Conv2d(in_dim, 2 * key_dim + value_dim, 1)
        self.out_proj = nn.Conv2d(value_dim, value_dim, 1)
        if pos_mode == "relative":
            self.position_embedding = RelativePosEmbedding(h=hw[0], w=hw[1], dim=self.dkh)
        elif pos_mode == "absolute":
            self.position_embedding = AbsolutePosEmbedding(h=hw[0], w=hw[1], dim=self.dkh)
        else:
            self.position_embedding = None  # position-unaware attention

    def split_heads_and_flatten(self, _x):
        return rearrange(_x, "b (nh hd) h w -> b nh hd (h w)", nh=self.nh)

    def forward(self, x):
        """
        :param x: B C H W
        """
        # Compute q, k, v.
        k, q, v = self.kqv_proj(x).split([self.key_dim, self.key_dim, self.value_dim], dim=1)
        q = q * self.dkh ** -0.5  # scaled dot-product
        # After splitting, shape is [B, Nh, dkh or dvh, HW].
        q, k, v = map(self.split_heads_and_flatten, (q, k, v))
        # [B, Nh, HW, HW]
        logits = torch.einsum("bndx, bndy -> bnxy", q, k)
        if self.position_embedding is not None:
            logits = logits + self.position_embedding(q)
        weights = logits.softmax(-1)
        attn_out = torch.einsum("bnxy, bndy -> bndx", weights, v)
        attn_out = rearrange(attn_out, "b nh hd (h w) -> b (nh hd) h w", h=x.shape[2], w=x.shape[3])
        # Project heads.
        attn_out = self.out_proj(attn_out)
        return attn_out


class AugmentedConv2d(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, key_dim, value_dim, num_heads, hw, pos_mode):
        super(AugmentedConv2d, self).__init__()
        self.std_conv = nn.Conv2d(in_dim, out_dim - value_dim, kernel_size, padding=kernel_size // 2)
        self.attention = SelfAttention2D(
            in_dim, key_dim=key_dim, value_dim=value_dim, nh=num_heads, hw=hw, pos_mode=pos_mode
        )

    def forward(self, x):
        # Concatenate the convolutional and attentional feature maps along the channel axis.
        conv_out = self.std_conv(x)
        attn_out = self.attention(x)
        return torch.cat([conv_out, attn_out], dim=1)


if __name__ == "__main__":
    m = AugmentedConv2d(
        in_dim=4, out_dim=64, kernel_size=3, key_dim=32, value_dim=48, num_heads=2, hw=(10, 10), pos_mode="relative"
    )
    print(m(torch.randn(4, 4, 10, 10)).shape)

Open Questions

  • What exactly is the difference between permutation equivariance and translation equivariance?

Supplementary Notes

Self-attention takes three inputs: query Q, key K and value V. What exactly do these three represent? The following is adapted from https://www.cnblogs.com/rosyYY/p/10115424.html:

  1. Q, K and V all contain embedded representations of the original data.
  2. Why is Q called a query?
    1. Because each time one embedding is used to "query" how well it matches every other embedding, which is exactly the attention weight.
  3. K and V are the keys and values. Explanations of these are often vague; the Zhihu article 從Seq2seq到Attention模型到Self Attention(二) (https://zhuanlan.zhihu.com/p/47470866) mentions that the key/value idea originates from the paper Key-Value Memory Networks for Directly Reading Documents, and that in NLP the key and value usually point to the same word embedding vector. I will not go into more detail here.

Related Links