einsum滿足你一切需要:深度學習中的愛因斯坦求和約定
作者:Tim Rocktäschel
編譯:weakish
【編者按】FAIR研究科學家Tim Rocktäschel簡要介紹了einsum表示法的概念,並通過真例項子展示了einsum的表達力。
當我和同事聊天的時候,我意識到不是所有人都瞭解 einsum ,我開發深度學習模型時最喜歡的函式。本文打算改變這一現狀,讓所有人都瞭解它!愛因斯坦求和約定(einsum)在numpy和TensorFlow之類的深度學習庫中都有實現,感謝 ofollow,noindex">Thomas Viehmann ,最近PyTorch也實現了這一函式。關於einsum的背景知識,我推薦閱讀Olexa Bilaniuk的 numpy的愛因斯坦求和約定 以及Alex Riley的einsum基本指南。這兩篇文章介紹了numpy中的einsum,我的這篇文章則將演示在編寫優雅的PyTorch/TensorFlow模型時,einsum是多麼有用(我將使用PyTorch作為例子,不過很容易就可以翻譯到TensorFlow)。
1. einsum記法
如果你像我一樣,發現記住PyTorch/TensorFlow中那些計算點積、外積、轉置、矩陣-向量乘法、矩陣-矩陣乘法的函式名字和簽名很費勁,那麼einsum記法就是我們的救星。einsum記法是一個表達以上這些運算,包括複雜張量運算在內的優雅方式,基本上,可以把einsum看成一種領域特定語言。一旦你理解並能利用einsum,除了不用記憶和頻繁查詢特定庫函式這個好處以外,你還能夠更迅速地編寫更加緊湊、高效的程式碼。而不使用einsum的時候,容易出現引入不必要的張量變形或轉置運算,以及可以省略的中間張量的現象。此外,einsum這樣的領域特定語言有時可以編譯到高效能程式碼,事實上,PyTorch最近引入的能夠自動生成GPU程式碼併為特定輸入尺寸自動調整程式碼的張量理解(Tensor Comprehensions)就基於類似einsum的領域特定語言。此外,可以使用 opt einsum 和 tf einsum opt 這樣的專案優化einsum表示式的構造順序。
比方說,我們想要將兩個矩陣 和
相乘,接著計算每列的和,最終得到向量
。使用愛因斯坦求和約定,這可以表達為:

這一表達式指明瞭c中的每個元素 是如何計算的,列向量
乘以行向量
,然後求和。注意,在愛因斯坦求和約定中,我們省略了求和符號Sigma,因為我們隱式地累加重複的下標(這裡是k)和輸出中未指明的下標(這裡是i)。當然,einsum也能表達更基本的運算。比如,計算兩個向量
的點積可以表達為:

在深度學習中,我經常碰到的一個問題是,變換高階張量到向量。例如,我可能有一個張量,其中包含一個batch中的N個訓練樣本,每個樣本是一個長度為T的K維詞向量序列,我想把詞向量投影到一個不同的維度Q。如果將這個張量記作 ,將投影矩陣記作
,那麼所需計算可以用einsum表達為:

最後一個例子,比方說有一個四階張量 ,我們想要使用之前的投影矩陣將第三維投影至Q維,並累加第二維,然後轉置結果中的第一維和最後一維,最終得到張量
。einsum可以非常簡潔地表達這一切:

注意,我們通過交換下標n和m( 而不是
),轉置了張量構造結果。
2. Numpy、PyTorch、TensorFlow中的einsum
einsum在numpy中實現為 np.einsum
,在PyTorch中實現為 torch.einsum
,在TensorFlow中實現為 tf.einsum
,均使用一致的簽名 einsum(equation, operands)
,其中 equation
是表示愛因斯坦求和約定的字串,而 operands
則是張量序列(在numpy和TensorFlow中是變長引數列表,而在PyTorch中是列表)。例如,我們的第一個例子,cj = ∑i∑kAikBkj寫成 equation
字串就是 ik,kj -> j
。注意這裡 (i, j, k)
的命名是任意的,但需要一致。
PyTorch和TensorFlow像numpy支援einsum的好處之一是einsum可以用於神經網路架構的任意計算圖,並且可以反向傳播。典型的einsum呼叫格式如下:

上式中◻是佔位符,表示張量維度。上面的例子中,arg1和arg3是矩陣,arg2是二階張量,這一einsum運算的結果(result)是矩陣。注意einsum處理的是可變數量的輸入。在上面的例子中,einsum指定了三個引數之上的操作,但它同樣可以用在牽涉一個引數、兩個引數、三個以上引數的操作上。學習einsum的最佳途徑是通過學習一些例子,所以下面我們將展示一下,在許多深度學習模型中常用的庫函式,用einsum該如何表達(以PyTorch為例)。
2.1 矩陣轉置

import torch a = torch.arange(6).reshape(2, 3) torch.einsum('ij->ji', [a]) tensor([[ 0.,3.], [ 1.,4.], [ 2.,5.]])
2.2 求和

a = torch.arange(6).reshape(2, 3) torch.einsum('ij->', [a]) tensor(15.)
2.3 列求和

a = torch.arange(6).reshape(2, 3) torch.einsum('ij->j', [a]) tensor([ 3.,5.,7.])
2.4 行求和

a = torch.arange(6).reshape(2, 3) torch.einsum('ij->i', [a]) tensor([3.,12.])
2.5 矩陣-向量相乘

a = torch.arange(6).reshape(2, 3) b = torch.arange(3) torch.einsum('ik,k->i', [a, b]) tensor([5.,14.])
2.6 矩陣-矩陣相乘

a = torch.arange(6).reshape(2, 3) b = torch.arange(15).reshape(3, 5) torch.einsum('ik,kj->ij', [a, b]) tensor([[25.,28.,31.,34.,37.], [70.,82.,94.,106.,118.]])
2.7 點積
向量:

a = torch.arange(3) b = torch.arange(3,6)# [3, 4, 5] torch.einsum('i,i->', [a, b]) tensor(14.)
矩陣:

a = torch.arange(6).reshape(2, 3) b = torch.arange(6,12).reshape(2, 3) torch.einsum('ij,ij->', [a, b]) tensor(145.)
2.8 哈達瑪積

a = torch.arange(6).reshape(2, 3) b = torch.arange(6,12).reshape(2, 3) torch.einsum('ij,ij->ij', [a, b]) tensor([[0.,7.,16.], [ 27.,40.,55.]])
2.9 外積

a = torch.arange(3) b = torch.arange(3,7) torch.einsum('i,j->ij', [a, b]) tensor([[0.,0.,0.,0.], [3.,4.,5.,6.], [6.,8.,10.,12.]])
2.10 batch矩陣相乘

a = torch.randn(3,2,5) b = torch.randn(3,5,3) torch.einsum('ijk,ikl->ijl', [a, b]) tensor([[[ 1.0886,0.0214,1.0690], [ 2.0626,3.2655, -0.1465]], [[-6.9294,0.7499,1.2976], [ 4.2226, -4.5774, -4.8947]], [[-2.4289, -0.7804,5.1385], [ 0.8003,2.9425,1.7338]]])
2.11 張量縮約
batch矩陣相乘是張量縮約的一個特例。比方說,我們有兩個張量,一個n階張量 ,一個m階張量
。舉例來說,我們取n = 4,m = 5,並假定
且
。我們可以將這兩個張量在這兩個維度上相乘(A張量的第2、3維度,B張量的3、5維度),最終得到一個新張量
,如下所示:

a = torch.randn(2,3,5,7) b = torch.randn(11,13,3,17,5) torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape torch.Size([2, 7, 11, 13, 17])
2.12 雙線性變換
如前所述,einsum可用於超過兩個張量的計算。這裡舉一個這方面的例子,雙線性變換。

a = torch.randn(2,3) b = torch.randn(5,3,7) c = torch.randn(2,7) torch.einsum('ik,jkl,il->ij', [a, b, c]) tensor([[ 3.8471,4.7059, -3.0674, -3.2075, -5.2435], [-3.5961, -5.2622, -4.1195,5.5899,0.4632]])
3. 案例
3.1 TreeQN
我曾經在實現TreeQN( arXiv:1710.11417)的等式6時使用了einsum:給定網路層l上的低維狀態表示zl,和啟用a上的轉換函式Wa,我們想要計算殘差連結的下一層狀態表示。

在實踐中,我們想要高效地計算大小為B的batch中的K維狀態表示 ,並同時計算所有轉換函式(即,所有啟用A)。我們可以將這些轉換函式安排為一個張量
,並使用einsum高效地計算下一層狀態表示。
import torch.nn.functional as F def random_tensors(shape, num=1, requires_grad=False): tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)] return tensors[0] if num == 1 else tensors # 引數 # -- [啟用數 x 隱藏層維度] b = random_tensors([5, 3], requires_grad=True) # -- [啟用數 x 隱藏層維度 x 隱藏層維度] W = random_tensors([5, 3, 3], requires_grad=True) def transition(zl): # -- [batch大小 x 啟用數 x 隱藏層維度] return zl.unsqueeze(1) + F.tanh(torch.einsum("bk,aki->bai", [zl, W]) + b) # 隨機取樣仿造輸入 # -- [batch大小 x 隱藏層維度] zl = random_tensors([2, 3]) transition(zl)
3.2 注意力
讓我們再看一個使用einsum的真例項子,實現注意力機制的等式11-13(arXiv:1509.06664):

用傳統寫法實現這些可要費不少力氣,特別是考慮batch實現。einsum是我們的救星!
# 引數 # -- [隱藏層維度] bM, br, w = random_tensors([7], num=3, requires_grad=True) # -- [隱藏層維度 x 隱藏層維度] WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True) # 注意力機制的單次應用 def attention(Y, ht, rt1): # -- [batch大小 x 隱藏層維度] tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr]) Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM) # -- [batch大小 x 序列長度] at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w])) # -- [batch大小 x 隱藏層維度] rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br) # -- [batch大小 x 隱藏層維度], [batch大小 x 序列維度] return rt, at # 取樣仿造輸入 # -- [batch大小 x 序列長度 x 隱藏層維度] Y = random_tensors([3, 5, 7]) # -- [batch大小 x 隱藏層維度] ht, rt1 = random_tensors([3, 7], num=2) rt, at = attention(Y, ht, rt1)
4. 總結
einsum是 一個函式走天下 ,是處理各種張量操作的瑞士軍刀。話雖如此,“einsum滿足你一切需要”顯然誇大其詞了。從上面的真實用例可以看到,我們仍然需要在einsum之外應用非線性和構造額外維度( unsqueeze
)。類似地,分割、連線、索引張量仍然需要應用其他庫函式。
使用einsum的麻煩之處是你需要手動例項化引數,操心它們的初始化,並在模型中註冊這些引數。不過我仍然強烈建議你在實現模型時,考慮下有哪些情況適合使用einsum.