Transformer模型的PyTorch實現
本文由 羅周楊 原創,轉載請註明作者和出處。未經授權,不得用於商業用途。
Google 2017年的論文 Attention is all you need 闡釋了什麼叫做大道至簡!該論文提出了 Transformer 模型,完全基於 Attention mechanism ,拋棄了傳統的 RNN 和 CNN 。
我們根據論文的結構圖,一步一步使用 PyTorch 實現這個 Transformer 模型。
Transformer架構
首先看一下transformer的結構圖:

解釋一下這個結構圖。首先, Transformer 模型也是使用經典的 encoer-decoder 架構,由encoder和decoder兩部分組成。
上圖的左半邊用 Nx
框出來的,就是我們的encoder的一層。encoder一共有6層這樣的結構。
上圖的右半邊用 Nx
框出來的,就是我們的decoder的一層。decoder一共有6層這樣的結構。
輸入序列經過 word embedding 和 positional encoding 相加後,輸入到encoder。
輸出序列經過 word embedding 和 positional encoding 相加後,輸入到decoder。
最後,decoder輸出的結果,經過一個線性層,然後計算softmax。
word embedding和 positional encoding 我後面會解釋。我們首先詳細地分析一下encoder和decoder的每一層是怎麼樣的。
Encoder
encoder由6層相同的層組成,每一層分別由兩部分組成:
- 第一部分是一個 multi-head self-attention mechanism
- 第二部分是一個 position-wise feed-forward network ,是一個全連線層
兩個部分,都有一個 殘差連線(residual connection) ,然後接著一個 Layer Normalization 。
如果你是一個新手,你可能會問:
- multi-head self-attention 是什麼呢?
- 參差結構是什麼呢?
- Layer Normalization又是什麼?
這些問題我們在後面會一一解答。
Decoder
和encoder類似,decoder由6個相同的層組成,每一個層包括以下3個部分:
- 第一個部分是 multi-head self-attention mechanism
- 第二部分是 multi-head context-attention mechanism
- 第三部分是一個 position-wise feed-forward network
還是和encoder類似,上面三個部分的每一個部分,都有一個 殘差連線 ,後接一個 Layer Normalization 。
但是,decoder出現了一個新的東西 multi-head context-attention mechanism 。這個東西其實也不復雜,理解了 multi-head self-attention 你就可以理解 multi-head context-attention 。這個我們後面會講解。
Attention機制
在講清楚各種attention之前,我們得先把attention機制說清楚。
通俗來說, attention 是指,對於某個時刻的輸出 y
,它在輸入 x
上各個部分的注意力。這個注意力實際上可以理解為 權重 。
attention機制也可以分成很多種。 Attention? Attention! 一問有一張比較全面的表格:

上面第一種 additive attention 你可能聽過。以前我們的seq2seq模型裡面,使用attention機制,這種**加性注意力(additive attention)**用的很多。Google的專案 tensorflow/nmt 裡面使用的attention就是這種。
為什麼這種attention叫做 additive attention 呢?很簡單,對於輸入序列隱狀態 和輸出序列的隱狀態 ,它的處理方式很簡單,直接 合併 ,變成
但是我們的transformer模型使用的不是這種attention機制,使用的是另一種,叫做 乘性注意力(multiplicative attention) 。
那麼這種 乘性注意力機制 是怎麼樣的呢?從上表中的公式也可以看出來: 兩個隱狀態進行點積 !
Self-attention是什麼?
到這裡就可以解釋什麼是 self-attention 了。
上面我們說attention機制的時候,都會說到兩個隱狀態,分別是 和 ,前者是輸入序列第i個位置產生的隱狀態,後者是輸出序列在第t個位置產生的隱狀態。
所謂 self-attention 實際上就是, 輸出序列 就是 輸入序列 !因此,計算自己的attention得分,就叫做 self-attention !
Context-attention是什麼?
知道了 self-attention ,那你肯定猜到了 context-attention 是什麼了: 它是encoder和decoder之間的attention !所以,你也可以稱之為 encoder-decoder attention !
context-attention一詞並不是本人原創,有些文章或者程式碼會這樣描述,我覺得挺形象的,所以在此沿用這個稱呼。其他文章可能會有其他名稱,但是不要緊,我們抓住了重點即可,那就是 兩個不同序列之間的attention ,與 self-attention 相區別。
不管是 self-attention 還是 context-attention ,它們計算attention分數的時候,可以選擇很多方式,比如上面表中提到的:
- additive attention
- local-base
- general
- dot-product
- scaled dot-product
那麼我們的Transformer模型,採用的是哪種呢?答案是: scaled dot-product attention 。
Scaled dot-product attention是什麼?
論文 Attention is all you need 裡面對於attention機制的描述是這樣的:
An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.
這句話描述得很清楚了。翻譯過來就是: 通過確定Q和K之間的相似程度來選擇V !
用公式來描述更加清晰:
scaled dot-product attention和 dot-product attention 唯一的區別就是, scaled dot-product attention 有一個縮放因子 。
上面公式中的 表示的是K的維度,在論文裡面,預設是 64
。
那麼為什麼需要加上這個縮放因子呢?論文裡給出瞭解釋:對於 很大的時候,點積得到的結果維度很大,使得結果處於softmax函式梯度很小的區域。
我們知道,梯度很小的情況,這對反向傳播不利。為了克服這個負面影響,除以一個縮放因子,可以一定程度上減緩這種情況。
為什麼是 呢?論文沒有進一步說明。個人覺得你可以使用其他縮放因子,看看模型效果有沒有提升。
論文也提供了一張很清晰的結構圖,供大家參考:

首先說明一下我們的K、Q、V是什麼:
- 在encoder的self-attention中,Q、K、V都來自同一個地方(相等),他們是上一層encoder的輸出。對於第一層encoder,它們就是word embedding和positional encoding相加得到的輸入。
- 在decoder的self-attention中,Q、K、V都來自於同一個地方(相等),它們是上一層decoder的輸出。對於第一層decoder,它們就是word embedding和positional encoding相加得到的輸入。但是對於decoder,我們不希望它能獲得下一個time step(即將來的資訊),因此我們需要進行 sequence masking 。
- 在encoder-decoder attention中,Q來自於decoder的上一層的輸出,K和V來自於encoder的輸出,K和V是一樣的。
- Q、K、V三者的維度一樣,即 。
上面scaled dot-product attention和decoder的self-attention都出現了 masking 這樣一個東西。那麼這個mask到底是什麼呢?這兩處的mask操作是一樣的嗎?這個問題在後面會有詳細解釋。
Scaled dot-product attention的實現
咱們先把scaled dot-product attention實現了吧。程式碼如下:
import torch import torch.nn as nn class ScaledDotProductAttention(nn.Module): """Scaled dot-product attention mechanism.""" def __init__(self, attention_dropout=0.0): super(ScaledDotProductAttention, self).__init__() self.dropout = nn.Dropout(attention_dropout) self.softmax = nn.Softmax(dim=2) def forward(self, q, k, v, scale=None, attn_mask=None): """前向傳播. Args: q: Queries張量,形狀為[B, L_q, D_q] k: Keys張量,形狀為[B, L_k, D_k] v: Values張量,形狀為[B, L_v, D_v],一般來說就是k scale: 縮放因子,一個浮點標量 attn_mask: Masking張量,形狀為[B, L_q, L_k] Returns: 上下文張量和attetention張量 """ attention = torch.bmm(q, k.transpose(1, 2)) if scale: attention = attention * scale if attn_mask: # 給需要mask的地方設定一個負無窮 attention = attention.masked_fill_(attn_mask, -np.inf) # 計算softmax attention = self.softmax(attention) # 新增dropout attention = self.dropout(attention) # 和V做點積 context = torch.bmm(attention, v) return context, attention 複製程式碼
Multi-head attention又是什麼呢?
理解了Scaled dot-product attention,Multi-head attention也很簡單了。論文提到,他們發現將Q、K、V通過一個線性對映之後,分成 份,對每一份進行 scaled dot-product attention 效果更好。然後,把各個部分的結果合併起來,再次經過線性對映,得到最終的輸出。這就是所謂的 multi-head attention 。上面的超引數 就是 heads 數量。論文預設是 8
。
下面是multi-head attention的結構圖:

值得注意的是,上面所說的 分成 份 是在 維度上面進行切分的。因此,進入到scaled dot-product attention的 實際上等於未進入之前的 。
Multi-head attention允許模型加入不同位置的表示子空間的資訊。
Multi-head attention的公式如下:
其中,
論文裡面, , 。所以在scaled dot-product attention裡面的
Multi-head attention的實現
相信大家已經理清楚了multi-head attention,那麼我們來實現它吧。程式碼如下:
import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, model_dim=512, num_heads=8, dropout=0.0): super(MultiHeadAttention, self).__init__() self.dim_per_head = model_dim // num_heads self.num_heads = num_heads self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads) self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads) self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads) self.dot_product_attention = ScaledDotProductAttention(dropout) self.linear_final = nn.Linear(model_dim, model_dim) self.dropout = nn.Dropout(dropout) # multi-head attention之後需要做layer norm self.layer_norm = nn.LayerNorm(model_dim) def forward(self, key, value, query, attn_mask=None): # 殘差連線 residual = query dim_per_head = self.dim_per_head num_heads = self.num_heads batch_size = key.size(0) # linear projection key = self.linear_k(key) value = self.linear_v(value) query = self.linear_q(query) # split by heads key = key.view(batch_size * num_heads, -1, dim_per_head) value = value.view(batch_size * num_heads, -1, dim_per_head) query = query.view(batch_size * num_heads, -1, dim_per_head) if attn_mask: attn_mask = attn_mask.repeat(num_heads, 1, 1) # scaled dot product attention scale = (key.size(-1) // num_heads) ** -0.5 context, attention = self.dot_product_attention( query, key, value, scale, attn_mask) # concat heads context = context.view(batch_size, -1, dim_per_head * num_heads) # final linear projection output = self.linear_final(context) # dropout output = self.dropout(output) # add residual and norm layer output = self.layer_norm(residual + output) return output, attention 複製程式碼
上面的程式碼終於出現了 Residual connection 和 Layer normalization 。我們現在來解釋它們。
Residual connection是什麼?
殘差連線其實很簡單!給你看一張示意圖你就明白了:

假設網路中某個層對輸入 x
作用後的輸出是 ,那麼增加 residual connection 之後,就變成了:
這個 +x
操作就是一個 shortcut 。
那麼 殘差結構 有什麼好處呢?顯而易見:因為增加了一項 ,那麼該層網路對x求偏導的時候,多了一個常數項 !所以在反向傳播過程中,梯度連乘,也不會造成 梯度消失 !
所以,程式碼實現residual connection很非常簡單:
def residual(sublayer_fn,x): return sublayer_fn(x)+x 複製程式碼
文章開始的transformer架構圖中的 Add & Norm
中的 Add
也就是指的這個 shortcut 。
至此, residual connection 的問題理清楚了。更多關於殘差網路的介紹可以看文末的參考文獻。
Layer normalization是什麼?
GRADIENTS, BATCH NORMALIZATION AND LAYER NORMALIZATION 一文對normalization有很好的解釋:
Normalization有很多種,但是它們都有一個共同的目的,那就是把輸入轉化成均值為0方差為1的資料。我們在把資料送入啟用函式之前進行normalization(歸一化),因為我們不希望輸入資料落在啟用函式的飽和區。
說到normalization,那就肯定得提到 Batch Normalization 。BN在CNN等地方用得很多。
BN的主要思想就是:在每一層的每一批資料上進行歸一化。
我們可能會對輸入資料進行歸一化,但是經過該網路層的作用後,我們的的資料已經不再是歸一化的了。隨著這種情況的發展,資料的偏差越來越大,我的反向傳播需要考慮到這些大的偏差,這就迫使我們只能使用較小的學習率來防止梯度消失或者梯度爆炸。
BN的具體做法就是對每一小批資料,在批這個方向上做歸一化。如下圖所示:

可以看到,右半邊求均值是 沿著資料批量N的方向進行的 !
Batch normalization的計算公式如下:
具體的實現可以檢視上圖的連結文章。
說完Batch normalization,就該說說咱們今天的主角 Layer normalization 。
那麼什麼是Layer normalization呢?:它也是歸一化資料的一種方式,不過LN是 在每一個樣本上計算均值和方差,而不是BN那種在批方向計算均值和方差 !
下面是LN的示意圖:

和上面的BN示意圖一比較就可以看出二者的區別啦!
下面看一下LN的公式,也BN十分相似:
Layer normalization的實現
上述兩個引數 和 都是可學習引數。下面我們自己來實現Layer normalization(PyTorch已經實現啦!)。程式碼如下:
import torch import torch.nn as nn class LayerNorm(nn.Module): """實現LayerNorm。其實PyTorch已經實現啦,見nn.LayerNorm。""" def __init__(self, features, epsilon=1e-6): """Init. Args: features: 就是模型的維度。論文預設512 epsilon: 一個很小的數,防止數值計算的除0錯誤 """ super(LayerNorm, self).__init__() # alpha self.gamma = nn.Parameter(torch.ones(features)) # beta self.beta = nn.Parameter(torch.zeros(features)) self.epsilon = epsilon def forward(self, x): """前向傳播. Args: x: 輸入序列張量,形狀為[B, L, D] """ # 根據公式進行歸一化 # 在X的最後一個維度求均值,最後一個維度就是模型的維度 mean = x.mean(-1, keepdim=True) # 在X的最後一個維度求方差,最後一個維度就是模型的維度 std = x.std(-1, keepdim=True) return self.gamma * (x - mean) / (std + self.epsilon) + self.beta 複製程式碼
順便提一句, Layer normalization 多用於RNN這種結構。
Mask是什麼?
現在終於輪到講解mask了!mask顧名思義就是 掩碼 ,在我們這裡的意思大概就是 對某些值進行掩蓋,使其不產生效果 。
需要說明的是,我們的Transformer模型裡面涉及兩種mask。分別是 padding mask 和 sequence mask 。其中後者我們已經在decoder的self-attention裡面見過啦!
其中, padding mask 在所有的scaled dot-product attention裡面都需要用到,而 sequence mask 只有在decoder的self-attention裡面用到。
所以,我們之前 ScaledDotProductAttention 的 forward
方法裡面的引數 attn_mask
在不同的地方會有不同的含義。這一點我們會在後面說明。
Padding mask
什麼是 padding mask 呢?回想一下,我們的每個批次輸入序列長度是不一樣的!也就是說,我們要對輸入序列進行 對齊 !具體來說,就是給在較短的序列後面填充 0
。因為這些填充的位置,其實是沒什麼意義的,所以我們的attention機制 不應該把注意力放在這些位置上 ,所以我們需要進行一些處理。
具體的做法是, 把這些位置的值加上一個非常大的負數(可以是負無窮),這樣的話,經過softmax,這些位置的概率就會接近0 !
而我們的padding mask實際上是一個張量,每個值都是一個 Boolen ,值為 False
的地方就是我們要進行處理的地方。
下面是實現:
def padding_mask(seq_k, seq_q): # seq_k和seq_q的形狀都是[B,L] len_q = seq_q.size(1) # `PAD` is 0 pad_mask = seq_k.eq(0) pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1)# shape [B, L_q, L_k] return pad_mask 複製程式碼
Sequence mask
文章前面也提到,sequence mask是為了使得decoder不能看見未來的資訊。也就是對於一個序列,在time_step為t的時刻,我們的解碼輸出應該只能依賴於t時刻之前的輸出,而不能依賴t之後的輸出。因此我們需要想一個辦法,把t之後的資訊給隱藏起來。
那麼具體怎麼做呢?也很簡單: 產生一個上三角矩陣,上三角的值全為1,下三角的值權威0,對角線也是0 。把這個矩陣作用在每一個序列上,就可以達到我們的目的啦。
具體的程式碼實現如下:
def sequence_mask(seq): batch_size, seq_len = seq.size() mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8), diagonal=1) mask = mask.unsqueeze(0).expand(batch_size, -1, -1)# [B, L, L] return mask 複製程式碼
哈佛大學的文章 The Annotated Transformer 有一張效果圖:

值得注意的是,本來mask只需要二維的矩陣即可,但是考慮到我們的輸入序列都是批量的,所以我們要把原本二維的矩陣擴張成3維的張量。上面的程式碼可以看出,我們已經進行了處理。
回到本小結開始的問題, attn_mask
引數有幾種情況?分別是什麼意思?
- 對於decoder的self-attention,裡面使用到的scaled dot-product attention,同時需要
padding mask
和sequence mask
作為attn_mask
,具體實現就是兩個mask相加作為attn_mask。 - 其他情況,
attn_mask
一律等於padding mask
。
至此,mask相關的問題解決了。
Positional encoding是什麼?
好了,終於要解釋 位置編碼 了,那就是文字開始的結構圖提到的 Positional encoding 。
就目前而言,我們的Transformer架構似乎少了點什麼東西。沒錯,就是 它對序列的順序沒有約束 !我們知道序列的順序是一個很重要的資訊,如果缺失了這個資訊,可能我們的結果就是:所有詞語都對了,但是無法組成有意義的語句!
為了解決這個問題。論文提出了 Positional encoding 。這是啥?一句話概括就是: 對序列中的詞語出現的位置進行編碼 !如果對位置進行編碼,那麼我們的模型就可以捕捉順序資訊!
那麼具體怎麼做呢?論文的實現很有意思,使用正餘弦函式。公式如下:
其中, pos
是指詞語在序列中的位置。可以看出,在 偶數位置,使用正弦編碼,在奇數位置,使用餘弦編碼 。
上面公式中的 是模型的維度,論文預設是 512
。
這個編碼公式的意思就是: 給定詞語的位置 ,我們可以把它編碼成 維的向量 !也就是說,位置編碼的每一個維度對應正弦曲線,波長構成了從 到 的等比序列。
上面的位置編碼是 絕對位置編碼 。但是詞語的 相對位置 也非常重要。這就是論文為什麼要使用三角函式的原因!
正弦函式能夠表達相對位置資訊。,主要數學依據是以下兩個公式:
上面的公式說明,對於詞彙之間的位置偏移 k
, 可以表示成 和 的組合形式,這就是表達相對位置的能力!
以上就是 E的所有祕密。說完了positional encoding,那麼我們還有一個與之處於同一地位的 word embedding 。
Word embedding大家都很熟悉了,它是對序列中的詞彙的編碼,把每一個詞彙編碼成 維的向量!看到沒有, Postional encoding是對詞彙的位置編碼,word embedding是對詞彙本身編碼 !
所以,我更喜歡positional encoding的另外一個名字 Positional embedding !
Positional encoding的實現
PE的實現也不難,按照論文的公式即可。程式碼如下:
import torch import torch.nn as nn class PositionalEncoding(nn.Module): def __init__(self, d_model, max_seq_len): """初始化。 Args: d_model: 一個標量。模型的維度,論文預設是512 max_seq_len: 一個標量。文字序列的最大長度 """ super(PositionalEncoding, self).__init__() # 根據論文給的公式,構造出PE矩陣 position_encoding = np.array([ [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)] for pos in range(max_seq_len)]) # 偶數列使用sin,奇數列使用cos position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2]) position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2]) # 在PE矩陣的第一行,加上一行全是0的向量,代表這`PAD`的positional encoding # 在word embedding中也經常會加上`UNK`,代表位置單詞的word embedding,兩者十分類似 # 那麼為什麼需要這個額外的PAD的編碼呢?很簡單,因為文字序列的長度不一,我們需要對齊, # 短的序列我們使用0在結尾補全,我們也需要這些補全位置的編碼,也就是`PAD`對應的位置編碼 pad_row = torch.zeros([1, d_model]) position_encoding = torch.cat((pad_row, position_encoding)) # 嵌入操作,+1是因為增加了`PAD`這個補全位置的編碼, # Word embedding中如果詞典增加`UNK`,我們也需要+1。看吧,兩者十分相似 self.position_encoding = nn.Embedding(max_seq_len + 1, d_model) self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False) def forward(self, input_len): """神經網路的前向傳播。 Args: input_len: 一個張量,形狀為[BATCH_SIZE, 1]。每一個張量的值代表這一批文字序列中對應的長度。 Returns: 返回這一批序列的位置編碼,進行了對齊。 """ # 找出這一批序列的最大長度 max_len = torch.max(input_len) tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor # 對每一個序列的位置進行對齊,在原序列位置的後面補上0 # 這裡range從1開始也是因為要避開PAD(0)的位置 input_pos = tensor( [list(range(1, len + 1)) + [0] * (max_len - len) for len in input_len]) return self.position_encoding(input_pos) 複製程式碼
Word embedding的實現
Word embedding應該是老生常談了,它實際上就是一個二維浮點矩陣,裡面的權重是可訓練引數,我們只需要把這個矩陣構建出來就完成了word embedding的工作。
所以,具體的實現很簡單:
import torch.nn as nn embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0) # 獲得輸入的詞嵌入編碼 seq_embedding = seq_embedding(inputs)*np.sqrt(d_model) 複製程式碼
上面 vocab_size
就是詞典的大小, embedding_size
就是詞嵌入的維度大小,論文裡面就是等於 。所以word embedding矩陣就是一個 vocab_size
* embedding_size
的二維張量。
如果你想獲取更詳細的關於word embedding的資訊,可以看我的另外一個文章 word2vec的筆記和實現 。
Position-wise Feed-Forward network是什麼?
這就是一個全連線網路,包含兩個線性變換和一個非線性函式(實際上就是ReLU)。公式如下:
這個線性變換在不同的位置都表現地一樣,並且在不同的層之間使用不同的引數。
論文提到,這個公式還可以用兩個核大小為1的一維卷積來解釋,卷積的輸入輸出都是 ,中間層的維度是 。
實現如下:
import torch import torch.nn as nn class PositionalWiseFeedForward(nn.Module): def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0): super(PositionalWiseFeedForward, self).__init__() self.w1 = nn.Conv1d(model_dim, ffn_dim, 1) self.w2 = nn.Conv1d(model_dim, ffn_dim, 1) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(model_dim) def forward(self, x): output = x.transpose(1, 2) output = self.w2(F.relu(self.w1(output))) output = self.dropout(output.transpose(1, 2)) # add residual and norm layer output = self.layer_norm(x + output) return output 複製程式碼
Transformer的實現
至此,所有的細節都已經解釋完了。現在來完成我們Transformer模型的程式碼。
首先,我們需要實現6層的encoder和decoder。
encoder程式碼實現如下:
import torch import torch.nn as nn class EncoderLayer(nn.Module): """Encoder的一層。""" def __init__(self, model_dim=512, num_heads=8, ffn_dim=2018, dropout=0.0): super(EncoderLayer, self).__init__() self.attention = MultiHeadAttention(model_dim, num_heads, dropout) self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout) def forward(self, inputs, attn_mask=None): # self attention context, attention = self.attention(inputs, inputs, inputs, padding_mask) # feed forward network output = self.feed_forward(context) return output, attention class Encoder(nn.Module): """多層EncoderLayer組成Encoder。""" def __init__(self, vocab_size, max_seq_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): super(Encoder, self).__init__() self.encoder_layers = nn.ModuleList( [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]) self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0) self.pos_embedding = PositionalEncoding(model_dim, max_seq_len) def forward(self, inputs, inputs_len): output = self.seq_embedding(inputs) output += self.pos_embedding(inputs_len) self_attention_mask = padding_mask(inputs, inputs) attentions = [] for encoder in self.encoder_layers: output, attention = encoder(output, self_attention_mask) attentions.append(attention) return output, attentions 複製程式碼
通過文章前面的分析,程式碼不需要更多解釋了。同樣的,我們的decoder程式碼如下:
import torch import torch.nn as nn class DecoderLayer(nn.Module): def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0): super(DecoderLayer, self).__init__() self.attention = MultiHeadAttention(model_dim, num_heads, dropout) self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout) def forward(self, dec_inputs, enc_outputs, self_attn_mask=None, context_attn_mask=None): # self attention, all inputs are decoder inputs dec_output, self_attention = self.attention( dec_inputs, dec_inputs, dec_inputs, self_attn_mask) # context attention # query is decoder's outputs, key and value are encoder's inputs dec_output, context_attention = self.attention( enc_outputs, enc_outputs, dec_output, context_attn_mask) # decoder's output, or context dec_output = self.feed_forward(dec_output) return dec_output, self_attention, context_attention class Decoder(nn.Module): def __init__(self, vocab_size, max_seq_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): super(Decoder, self).__init__() self.num_layers = num_layers self.decoder_layers = nn.ModuleList( [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]) self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0) self.pos_embedding = PositionalEncoding(model_dim, max_seq_len) def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None): output = self.seq_embedding(inputs) output += self.pos_embedding(inputs_len) self_attention_padding_mask = padding_mask(inputs, inputs) seq_mask = sequence_mask(inputs) self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0) self_attentions = [] context_attentions = [] for decoder in self.decoder_layers: output, self_attn, context_attn = decoder( output, enc_output, self_attn_mask, context_attn_mask) self_attentions.append(self_attn) context_attentions.append(context_attn) return output, self_attentions, context_attentions 複製程式碼
最後,我們把encoder和decoder組成Transformer模型!
程式碼如下:
import torch import torch.nn as nn class Transformer(nn.Module): def __init__(self, src_vocab_size, src_max_len, tgt_vocab_size, tgt_max_len, num_layers=6, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.2): super(Transformer, self).__init__() self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout) self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout) self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False) self.softmax = nn.Softmax(dim=2) def forward(self, src_seq, src_len, tgt_seq, tgt_len): context_attn_mask = padding_mask(tgt_seq, src_seq) output, enc_self_attn = self.encoder(src_seq, src_len) output, dec_self_attn, ctx_attn = self.decoder( tgt_seq, tgt_len, output, context_attn_mask) output = self.linear(output) output = self.softmax(output) return output, enc_self_attn, dec_self_attn, ctx_attn 複製程式碼
至此,Transformer模型已經實現了!