[原始碼解析] PyTorch 流水線並行實現 (4)--前向計算

0x00 摘要

前幾篇文章我們介紹了 PyTorch 流水線並行的基本知識,自動平衡機制和切分資料,本文我們結合論文內容來看看如何保證前向計算執行順序。

流水線並行其他文章連結如下:

[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現

[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

[原始碼解析] 深度學習流水線並行之PipeDream(1)--- Profile階段

[原始碼解析] 深度學習流水線並行 PipeDream(2)--- 計算分割槽

[原始碼解析] 深度學習流水線並行 PipeDream(3)--- 轉換模型

[原始碼解析] 深度學習流水線並行 PipeDream(4)--- 執行時引擎

[原始碼解析] 深度學習流水線並行 PipeDream(5)--- 通訊模組

[原始碼解析] 深度學習流水線並行 PipeDream(6)--- 1F1B策略

[原始碼解析] PyTorch 流水線並行實現 (1)--基礎知識

[原始碼解析] PyTorch 流水線並行實現 (2)--如何劃分模型

[原始碼解析] PyTorch 流水線並行實現 (3)--切分資料和執行時系統

本文圖片來自論文和github原始碼。

0x01 論文

之前我們提到過,因為 GPipe 是基於 TensorFlow 的庫(當然了,這是Google的產品嘛),所以kakaobrain的一些工程師就用PyTorch 來實現了 GPipe,並且開源出來,這就是 torchgpipe,其地址為:https://github.com/kakaobrain/torchgpipe,使用者可以通過 pip install torchgpipe 進行安裝使用。

該作者團隊還發表了一篇論文,具體如下:https://arxiv.org/pdf/2004.09910.pdf。

接下來我們就圍繞這篇論文進行分析,本文不會全部翻譯這篇論文,而是選擇與實現密切相關的部分進行翻譯分析。

1.1 引論

並行訓練的一個障礙是:訓練神經網路的常用優化技術本質上是順序的。這些演算法反覆執行如下操作:對於給定的小批次(mini-batch)資料,計算其針對損失函式的梯度,並且使用這些梯度來更新模型引數。

1.1.1 資料並行

在有大量計算資源的情況下,資料並行將小批量(mini-batch)劃分為微批量(micro-batch)並將每個微批量的計算委託給可用裝置,以此來加速整體優化過程。通過仔細的超引數調整,資料並行可以有效地將訓練時間減少到一定規模的小批量所需的訓練時間,這可能取決於模型、優化演算法和資料。

資料並行訓練的問題則是,每個裝置擁有自己的模型網路版本來執行子任務,並且在每次引數更新後必須同步模型網路引數。當有許多引數需要同步時,這可能會導致沉重的通訊負載。

但是,當模型太大以至於即使將單個機器無法容納模型,也無法計算梯度時,資料並行性不適用。

1.1.2 模型並行

模型並行性是一種訓練龐大模型的方法,它將模型分成若干部分,並將它們放在不同的裝置上。每個裝置只計算模型的一小部分,並且只更新該部分中的引數。然而,模型並行性受到其"無法充分利用"行為的影響。因為大多數神經網路由一系列的層組成,持有模型後期部分的裝置必須等待直到持有模型早期部分的裝置計算結束。

一種可能的解決方案是使用梯度檢查點,它只儲存啟用值的子集,並在需要時重新計算丟棄的啟用值,從而節省記憶體。顯然,這需要對模型的某些部分進行兩次計算,並增加總體訓練時間。

在後續部分,我們將討論如何將前向和後向過程分解為子任務(在某些假設下),描述微批次管道並行的裝置分配策略,並演示每個裝置所需的執行順序。也會討論在PyTorch中實現管道並行最佳時間線的複雜之處,並解釋torchgpipe如何解決這些問題。

此外,我們放鬆了模型是按順序組合的假設,並提供了一種使用長跳躍連線表示模型的方法,以便在不放棄效率的情況下仍然應用管道並行性。

1.2 模型定義

假定我們有一個神經網路,其由一系列子網路構成。我們假定這些子網路是 \(f^1,...,f^n\),其引數分別是 \(\theta^1,...,\theta^n\),則整個網路是:

引數是 \(\theta = (\theta^1,...,\theta^n)\),為了清楚起見,我們稱 \(f^j\) 表示 f 的第 j 個分割槽,並假設分割槽的引數是相互不相交的。

在訓練網路時,基於梯度的方法(如隨機梯度下降法)需要在給定小批量訓練資料 x 和相應損失之後,計算網路的輸出結果f(x)。以及損失相對於網路引數 \(\theta\) 的梯度g。這兩個階段分別稱為向前傳播和向後傳播。

既然 f 由其 L 層 子模組 (\(f^L, f^{L-1},...f^1\)) 順序組成,那麼前向傳播\(f(x)\) 可以通過如下方式計算:讓 \(x^0=x\)(就是輸入x),然後順序應用每一個 partition,即 \(x^j = f^j (x^{j-1})\),這裡 $ j = 1, ..., L$。就是 \(f(x)\) 可以表示為 :

\[f(x) = f^L(f^{L-1}(f^{L-2}(... f^1(x))))
\]

再進一步,令 x 由 m 個更小的批次 \(x_1,...,x_m\) 組成,這些更小的批次叫做微批次(micro-batches)。則\(f(x)\) 的計算可以進一步分割為小的 tasks \(F_{i,j}\),這裡 \(x^0_i = x_i\),所以得到定義:

\[x_i^j \leftarrow f^j(x_i^{j-1})\qquad\qquad\qquad (F_{i,j})
\]

這裡 $ i = 1,..,m$ 和 $ j = 1,...,n$,假定 f 不參與任何 intra-batch 的計算。

用同樣的方式,後向傳播也被分割為 task,\(B_{i,j}\),這裡 \(dx^n_j\) 是損失對於 \(x_j^n\) 的梯度。

因此

是通過分割槽 \(f^j\) 來計算後向傳播(也叫vector-Jacobian product)的函式。

最終,我們依靠把 \(g_i^j\) 通過 i 來求和來得到損失針對 \(\theta ^ j\) 的梯度。

需要注意的是在tasks之間有資料依賴。比如 \(F_{i,j}\) 需要 \(x_i^{j-1}\),而\(x_i^{j-1}\) 只有在 \(F_{i,j-1}\) 計算完成之後才有效,因此,\(F_{i,j-1}\) 必須在 \(F_{i,j}\) 開始之前結束。同理, \(B_{i,j}\) 必須在 \(B_{i,j+1}\) 之前結束。

下圖就是一個依賴圖,這裡 $ m = 4, n = 3$。即,模型被分成3個子網路,小批次被分割成 4個微批次。

前面三個 F 是三個子網路的前向傳播,後面三個 B 是三個子網路的後向傳播。

下面表示第一個微批次,順序完成三個子網的前向傳播和後向傳播。

給定 task 的集合 \({F_{i,j}}\) 和 \({B_{i,j}}\) ,和一個可以並行工作的裝置池,不同的並行化策略有自己分配任務給裝置的規則。

一旦解決依賴關係,每個裝置就會計算一個或多個分配的任務。在上面的設定中,tasks 的所有依賴項都具有相同微批次索引 i。因此,通過將具有不同微批量索引的任務分配給不同的裝置,可以有效地並行化任務,這就是資料並行。

1.3 GPipe計算圖

管道並行的策略是根據分割槽索引 j 分配任務,以便第 j 個分割槽完全位於第 j 個裝置中。除此之外,還強制要求 \(F_{i,j}\) 必須在 \(F_{i+1,j}\) 之前完成,和 \(B{i,j}\) 必須在執行\(B{i-1,j}\)之前完成。

除了微批量流水線之外,GPipe還通過對每個\(B{i,j}\) 使用梯度檢查點進一步降低了記憶體需求。因為第 \(jth\) 個裝置每次只執行 \(B{i,j}\) ,所以當計算\(B{i,j}\) 時候,只需要拿到 \(F{i,j}\) 的啟用圖。

因為恰恰在執行 \(B{i,j}\) 之前計算前向傳播\(F{i,j}\),所以我們記憶體消耗減少了m倍。此外,當裝置等待\(B{i,j}\) 時,可以進行重新計算,這些資訊如下圖所示:

其中虛線箭頭表示因為引入了微批次順序而帶來的獨立任務之間的執行順序。顏色表示不同的裝置。

我們注意到最後一個微批次的重新計算,即 \(F^{'}_{m,j}\) ,這裡 \(j=1,...,n\)是不必要的。

這是因為在第j臺裝置上,前向傳遞中的最後一個任務是 \(F{m,j}\) 、 因此,在前向傳遞中放棄中間啟用,並在後向傳遞開始時重新計算它們,不會減少記憶體,只會減慢管道速度。因此,圖中省略了 \(F^{'}_{m,j}\) 。

1.4 裝置執行順序(Devicewise Execution Order)

總之,在流水線並行性(帶有檢查點)中,每個裝置都被分配了一組具有指定順序的任務。一旦滿足跨裝置依賴關係,每個裝置將逐個執行給定的任務。然而,這個圖片中缺少一個元件——裝置之間的資料傳輸。為了便於說明,裝置 j 必須遵循的完整執行如圖所示順序。而且為了強調,資料傳輸操作被明確表示為“接收”和“傳送”。

為方便起見,庫提供了子模組 torchgpipe.balance 來計算得到分割槽,目的是讓兩兩分割槽(pairwise)之間的資源差別儘量小。資源佔用情況是通過分析(profile)來計算。具體是使用了 [2] Imre B´ar´any and Victor S Grinberg. Block partitions of sequences. Israel Journal of Mathematics, 206(1):155–164, 之中的演算法。

1.5 PyTorch 實現難點

我們最關心的是效率。為了使管道並行性按預期工作,必須以正確的順序將任務分配給每個裝置。在Pytorch中實現這一點有幾個複雜之處。

  • 首先,由於PyTorch的define by run風格及其eager execution的執行行為(與in construct-and-run 框架相反),核函式(kernel)被動態地釋出到每個裝置。

    • 因此,必須仔細設計主機程式碼(host code),這樣不僅可以在每個裝置中以正確的順序釋出繫結到裝置的任務,而且還可以避免由於Python直譯器未能提前請求而延遲在裝置上(與CPU非同步)執行任務。
    • 當某些任務是CPU密集型任務或涉及大量廉價kernel呼叫時,可能會發生這種延遲。作為一種解決方案,torchgpipe引入了確定性時鐘週期(deterministic clock-cycle),它給出了任務的總體順序
  • 其次,後向傳播的計算圖是在前向傳播過程中動態構造的。換句話說,“它避免了“正向圖”的具體化,只記錄微分計算所需的內容”。因為PyTorch既不記錄正向計算圖,也不維護一個梯度磁帶(gradient tape),PyTorch的自動微分(autograd)引擎僅對計算圖進行反向傳播。這意味著自動載入引擎可能不會完全按照與正向過程相反的執行順序執行,除非由圖的結構強制執行。為了解決這個問題,torchgpipe開發了一對名為“fork”和“join”的基本函式,在後向計算圖中動態建立顯式依賴關係

  • 第三,如果不小心管理,多個裝置之間的通訊可能導致雙向同步。這會導致利用率不足,因為即使在副本和佇列中的下一個任務之間沒有顯式依賴關係時,傳送方也可能等待與接收方同步,反之亦然。torchgpipe通過使用非預設CUDA流避免了這個問題,這樣副本就不會阻止計算,除非計算必須等待資料

  • 最後,torchgpipe試圖放寬微批處理流水線並行性的限制(模型必須是順序的)。

    • 儘管原則上任何神經網路都可以以順序形式編寫,但這需要提前知道整個計算圖,而PyTorch中則不是這樣。特別是,如果有一個張量從裝置 \(j^{'}\)中的一層跳到裝置 \(j>j^{'}+1\)中的另一層,則該張量將被複制到中間的所有裝置,因為torchgpipe無法提前知道它。為了避免這個問題,我們設計了一個介面來表示跳過了哪些中間張量以及哪些層使用了它們。

1.6 總結

我們總結一下目前核心難度,從而引入下面的工作。

  • 原始流水線狀態如下:

    • 管道並行的策略是根據分割槽索引 j 分配任務,以便第 j 個分割槽完全位於第 j 個裝置中。
    • 持有模型後期部分的裝置必須等待,直到持有模型早期部分的裝置計算結束。

  • 目標流水線狀態如下:

  • 目前問題

    • 如果分成若干個微批次,則需要強制要求 \(F_{i,j}\) 必須在 \(F_{i+1,j}\) 之前完成,以及 \(B{i,j}\) 必須在執行\(B{i-1,j}\) 之前完成。
    • 後向傳播的計算圖是在前向傳播過程中動態構造的。PyTorch既不記錄正向計算圖,也不維護一個梯度磁帶(gradient tape),PyTorch的自動微分(autograd)引擎僅對計算圖進行反向傳播。這意味著自動載入引擎可能不會完全按照與正向過程相反的執行順序執行,除非由圖的結構強制執行

  • 目前難點

    • 如何在每個裝置中以正確的順序釋出那些繫結到裝置的任務,以避免由於Python直譯器未能提前請求而延遲在裝置上(與CPU非同步)執行任務。
    • 如何建立這些小批次之間的跨裝置依賴關係
  • 實現方案

    • 如何保證正確執行順序?torchgpipe引入了確定性時鐘週期(deterministic clock-cycle),它給出了任務的總體順序
    • 如何保證計算圖中的動態顯式依賴關係?針對clock_cycles產生的每一個執行計劃:
      • 利用 fence 函式呼叫“fork”和“join”,以此在向後計算圖中動態建立顯式後向傳播依賴關係。
      • 利用 compute(schedule, skip_trackers, in_queues, out_queues) 進行計算。

本文就首先看看前向計算中,如何保證正確執行順序。

0x02 執行順序

下面我們看看確定性時鐘週期演算法(Forward Dependency: Deterministic Clock-cycle)。這個排序就是專門在前向傳播中使用,前向傳播按照這個演算法來進行逐一計算

一般來說,前向傳播計算是按照模型結構來完成的,但是因為流水線並行是特殊的,模型已經被分割開了,所以 torch-gpipe 需要自己提供一個前向傳播執行序列以執行各個微批次

2.1 論文內容

任務的總順序由前向傳播中的主機程式碼決定。每個裝置通過CPU分配的順序隱式地理解任務之間的依賴關係。理想情況下,如果可以無代價的將任務分配給裝置,只要裝置內的順序正確,CPU就可以按任何順序將任務分配給裝置。然而,這種假設不夠現實,因為在GPU上啟動核函式對CPU來說不是免費的,GPU之間的記憶體傳輸可能需要同步,或者任務是CPU密集型的。因此,為了最小化來自CPU的延遲,我們通過"某節點到\(F_{1,1}\)的距離"對所有任務進行排序

我們把這種方案命名為確定性時鐘週期(deterministic clock-cycle)演算法。在該演算法中,CPU在計數器 \(k=1\)到\(k=m+n-1\) 的時鐘週期內執行。在第k個時鐘週期內,對於 $ i +j-1 = k$ 這些index:

  • 首先執行 task \(F{i,j}\) 所需資料的所有複製(copy)核函式。
  • 然後將用於執行任務的計算核函式註冊到相應的裝置(由於同一時鐘週期中的任務是獨立的,因此可以安全地進行多執行緒處理)。

2.2 解析

我們結合論文的圖片看看,即:

  • clock 1 時候,執行圖上的 \(F_{1,1}\)
  • clock 2 時候,執行圖上的 \(F_{2,1},F_{1,2}\)。就是向右執行一格到\(F_{1,2}\),同時第二個微批次進入訓練,即執行\(F_{2,1}\) 。
  • clock 3 時候,執行圖上的 \(F_{3,1},F_{2,2},F_{1,3}\)。就是 \(F_{1,2}\)向右執行一格到\(F_{1,3}\), \(F_{2,1}\) 向右執行一格到 \(F_{2,3}\) ,同時第三個微批次進入訓練流程,即執行 \(F_{3,1}\) 。
  • clock 4 時候,執行圖上的 \(F_{4,1},F_{3,2},F_{2,3}\)。就是 \(F_{2,2}\)向右執行一格到\(F_{2,3}\), \(F_{3,1}\) 向右執行一格到 \(F_{3,2}\) ,同時第四個微批次進入訓練流程,即執行 \(F_{4,1}\) 。
  • 依次類推.....

對應到圖上,我們可以看到,

  • \(F_{2,1},F_{1,2}\) 到 \(F_{1,1}\) 的步進距離是1,走一步可到。

  • \(F_{3,1},F_{2,2},F_{1,3}\) 到 \(F_{1,1}\) 的步進距離是2,分別走兩步可到。

這個邏輯從下圖可以清晰看到。所以,這個clock的演算法就是 利用任務到\(F_{1,1}\)的距離對所有任務進行排序。這個很像把一塊石頭投入水中,泛起的水波紋一樣,從落水點一層一層的從近處向遠處傳播

這裡顏色表示不同的裝置。

2.3 程式碼

我們再來看看程式碼。首先是生成時鐘週期,這裡:

  • min(1+k, n) 就是在 k 時鐘時候,可以啟動的最大device數目(partition)。
  • max(1+k-m, 0) 就是在 k 時鐘時候,可以啟動的最小微batch(micro-batch)。

所以最終返回的序列就是k 時鐘時候,可以啟動的(index of micro-batch,index of partition)序列。

def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
"""Generates schedules for each clock cycle."""
# m: number of micro-batches
# n: number of partitions
# i: index of micro-batch
# j: index of partition
# k: clock number
#
# k (i,j) (i,j) (i,j)
# - ----- ----- -----
# 0 (0,0)
# 1 (1,0) (0,1)
# 2 (2,0) (1,1) (0,2)
# 3 (2,1) (1,2)
# 4 (2,2)
# 我們解析一下,這裡 k 就是時鐘數,從1開始,最多時鐘序號就是 m+n-1。
# min(1+k, n) 就是在 k 時鐘時候,可以啟動的最大device數目
# max(1+k-m, 0) 就是在 k 時鐘時候,可以啟動的最小batch
for k in range(m+n-1):
yield [(k-j, j) for j in range(max(1+k-m, 0), min(1+k, n))]

設定 m = 4, n =3,solve(4,3) 的輸出是:

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]

因為論文有一個示例圖,而這個圖和註釋&程式碼不完全一致,為了更好的說明,我們就按照圖上來,因為圖片是從 \(F_{1,1}\)開始,所以我們把註釋修正以下:

# 0 (0,0)                   ----> clock 1 執行圖上的 (1,1)
# 1 (1,0) (0,1) ----> clock 2 執行圖上的 (2,1) (1,2)
# 2 (2,0) (1,1) (0,2) ----> clock 3 執行圖上的 (3,1) (2,2) (1,3)
# 3 (2,1) (1,2) ----> clock 4 執行圖上的 (3,2) (2,3)
# 4 (2,2) ----> clock 5 執行圖上的 (3,3)

我們把 solve程式碼修改下,為了列印正確的index,這樣大家就可以更好的把程式碼和圖片對應起來了。

m=4 # m: number of micro-batches
n=3 # n: number of partitions
for k in range(m + n - 1):
print( [(k - j + 1 , j +1 ) for j in range(max(1 + k - m, 0), min(1 + k, n))] ) 列印是:
[(1, 1)] # 第 1 輪訓練計劃 & 資料
[(2, 1), (1, 2)] # 第 2 輪訓練計劃 & 資料
[(3, 1), (2, 2), (1, 3)] # 第 3 輪訓練計劃 & 資料
[(4, 1), (3, 2), (2, 3)] # 第 4 輪訓練計劃 & 資料
[(4, 2), (3, 3)] # 第 5 輪訓練計劃 & 資料
[(4, 3)] # 第 6 訓練計劃 & 資料

我們把流水線的圖再祭出來看看。

我們把上面的輸出按照流水線的圖繪製一下作為比對。

可以看到,前 4 個時鐘週期內,分別有 4 個 micro-batch 進入了 cuda:0,分別是(1,1) (2,1) (3,1) (4,1) 。然後按照 clock_cycles 演算法給出的順序,每次迭代(時鐘週期)內執行不同的schedule,經過了 6 個時鐘週期之後,完成了第一輪 forward 操作。這就形成了流水線。

流水線優勢在於,如果 number of micro-batches 配置的合適,那麼可以在每個時鐘週期內,最大程度的讓所有裝置都執行起來。與之對比,原生流水線每一時間只能讓一個裝置互活躍。

           +          +          +          +          +          +          +
| | | | | | |
| | | | | | |
cuda:0 | (1,1) | (2,1) | (3,1) | (4,1) | | |
| | | | | | |
| | | | | | |
| | | | | | |
| | | | | | |
cuda:1 | | (1,2) | (2,2) | (3,2) | (4,2) | |
| | | | | | |
| | | | | | |
| | | | | | |
| | | | | | |
cuda:2 | | | (1,3) | (2,3) | (3,3) | (4,3) |
| | | | | | |
| | | | | | |
| | | | | | |
| clock 1 | clock 2 | clock 3 | clock 4 | clock 5 | clock 6 |
+ + + + + + + +------------------------------------------------------------------------------> Time

具體資料batch的走向是:

         +             +            +             +            +            +             +
| | | | | | |
cuda:0 | (1,1) | (2,1) | (3,1) | (4,1) | | |
| + | + | + | + | | |
| | | | | | | | | | |
| | | | | | | +----------+ | |
| | | | | +-----------+ | | | |
| | | +------------+ | | | | | |
| | | | | | | | | | |
| +------------+ | | | | | | | |
| | | | | | | | | | |
| | | | v | v | v | |
| | v | | | | |
cuda:1 | | (1,2) | (2,2) | (3,2) | (4,2) | |
| | + | + | + | + | |
| | | | | | | | | | |
| | | | | | | | +-------------+ |
| | | | | | +----------+ | | |
| | | | +------------+ | | | | |
| | +-----------+ | | | | | | |
| | | | | v | v | v |
| | | v | | | |
cuda:2 | | | (1,3) | (2,3) | (3,3) | (4,3) |
| | | | | | |
| | | | | | |
| | | | | | |
| clock 1 | clock 2 | clock 3 | clock 4 | clock 5 | clock 6 |
+ + + + + + + +-----------------------------------------------------------------------------------> Time

2.4 使用

在 Pipeline 類之中,我們可以看到,就是按照時鐘週期來啟動計算,這樣在前向傳播之中,就按照這個序列,像水波紋一樣擴散。

    def run(self) -> None:
"""Runs pipeline parallelism. It modifies the given batches in place. """
batches = self.batches
partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout m = len(batches)
n = len(partitions) skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] with spawn_workers(devices) as (in_queues, out_queues):
for schedule in clock_cycles(m, n): # 這裡使用,給出了執行序列計劃,後續按照這個來執行
self.fence(schedule, skip_trackers) # 構建後向傳播依賴關係
self.compute(schedule, skip_trackers, in_queues, out_queues) # 進行計算

至此,前向傳播過程分析完畢,下一篇我們分析依賴關係。

0xFF 參考

Markdown公式用法大全

markdown中公式編輯教程

https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior

CUDA學習:基礎知識小結

CUDA隨筆之Stream的使用

NVIDIA解決方案架構師深度解析大規模引數語言模型Megatron-BERT

Accelerating Wide & Deep Recommender Inference on GPUs

HugeCTR: High-Performance Click-Through Rate Estimation Training

https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548

https://github.com/NVIDIA/apex/

https://github.com/justheuristic/prefetch_generator

https://pytorch.org/tutorials/intermediate/model_parallel_turotial.html

https://pytorch.org/docs/stable/autograd.html

https://pytorch.org/docs/notes/cuda.html

https://zhuanlan.zhihu.com/p/61765561

https://pytorch.apachen.org/docs/1.7/64.html

https://zhidx.com/p/217999.html