[原始碼解析] 深度學習分散式訓練框架 horovod (21) --- 之如何恢復訓練

0x00 摘要

本文以 PyTorch on Horovod 為切入點,分析一下 Horovod 彈性訓練的恢復流程,具體涉及知識點有:

ElasticSampler與PyTorch 原生DistributedSampler 的區別,Horovod 彈性訓練如何恢復等。

本系列其他文章連結如下:

[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識

[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入

[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼

[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver

[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架

[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構

[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer

[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark

[原始碼解析] 深度學習分散式訓練框架 horovod (11) --- on spark --- GLOO 方案

[原始碼解析] 深度學習分散式訓練框架 horovod (12) --- 彈性訓練總體架構

[原始碼解析] 深度學習分散式訓練框架 horovod (13) --- 彈性訓練之 Driver

[原始碼解析] 深度學習分散式訓練框架 horovod (14) --- 彈性訓練發現節點 & State

[原始碼解析] 深度學習分散式訓練框架 horovod (15) --- 廣播 & 通知

[原始碼解析] 深度學習分散式訓練框架 horovod (16) --- 彈性訓練之Worker生命週期

[原始碼解析] 深度學習分散式訓練框架 horovod (17) --- 彈性訓練之容錯

[原始碼解析] 深度學習分散式訓練框架 horovod (18) --- kubeflow tf-operator

[原始碼解析] 深度學習分散式訓練框架 horovod (17) --- 彈性訓練之容錯

[原始碼解析] 深度學習分散式訓練框架 horovod (18) --- kubeflow tf-operator

[原始碼解析] 深度學習分散式訓練框架 horovod (19) --- kubeflow MPI-operator

[原始碼解析] 深度學習分散式訓練框架 horovod (20) --- Elastic Training Operator

0x01 總論

本文緣起於一個兄弟的留言:

請問在彈性訓練中,如果節點數目發生變化,資料怎麼重新劃分呢?比如一個epoch還沒有進行完,這時添加了新節點,新資料重新劃分的話,當前記憶體中用舊資料訓練的模型還有效嗎?

我恰好在分析PyTorch分散式的時候也有類似疑問,所以就回頭再看看Horovod是如何實現的。

我們之前對於 Horovod 的分析和示例大多以 TensorFlow 為例。大家對各種框架如何在Horovod之中適配的總體邏輯和思路應該有了一個大致的認識,所以我們本部分主要看看一些PyTorch 相關的特殊之處。

使用PyTorch做切入的另外一個原因是:在恢復訓練這個流程上,PyTorch相關部分確實相對清晰明確。

在 horovod/torch/elastic/ 目錄下,有兩個檔案 :state.py 和 sampler.py。既然是彈性相關,所以我們先來看看其特殊之處。

0x02 Sampler

在 horovod/torch/elastic/sampler.py 之中,有一個 ElasticSampler 類,我們看看具體針對彈性做了哪些處理。

因為 ElasticSampler 類之中註明,它的實現非常類似DistributedSampler,也就是 PyTorch 原生的實現,所以我們要先看看 DistributedSampler

2.1 PyTorch Distributed Optimizer

2.1.1 定義

DistributedSampler程式碼位於:torch/distributed/optim/optimizer.py。

總結一下DistributedSampler的分配方法是:每段連續的 num_replicas 個數據被拆成一個一個,分給 num_replicas 個程序,這樣就達到了不重疊不交叉的目的,但也要注意的是:這樣每個程序拿到的資料是不連續的

__iter__ 程式碼的一個技術細節是 本worker如何遍歷?

indices = indices[self.rank:self.total_size:self.num_replicas]

這裡,num_replicas 實際就是rank的總數,起始位置是self.rank,結束位置是總資料長度,按照num_replicas(就是world size)作為步長來遞增,所以這裡每個worker就會嚴格返回自己rank對應的那部分資料序號

我們用一個例子來看看,比如:

a = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
print(a[0:15:3])
print(a[1:15:3])
print(a[2:15:3])

得到:

[1, 4, 7, 10, 13]
[2, 5, 8, 11, 14]
[3, 6, 9, 12, 15]

具體程式碼如下:

class DistributedSampler(Sampler[T_co]):

    def __iter__(self) -> Iterator[T_co]:

        if self.shuffle: # 如果需要shuffle,則會基於epoch和seed進行處理
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else: # 否則直接返回資料集長度序列
indices = list(range(len(self.dataset))) # type: ignore[arg-type] # 是否需要補齊資料
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size # subsample
# 依據自己的rank,依次返回自己的資料序號
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples return iter(indices) # 後續就使用這些indices來對資料進行提取 def __len__(self) -> int:
return self.num_samples def set_epoch(self, epoch: int) -> None:
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering. Args:
epoch (int): Epoch number.
"""
self.epoch = epoch

2.1.2 問題點

DistributedSampler 如果直接用到 彈性訓練,是有一定問題的,讓我們分析一下,有幾個問題:

  • 如果使用者已經訓練了5輪,那麼就意味著已經使用了前面5個批次的資料。假設此時加入了新的worker節點,那麼就應該恢復訓練。那麼對於已經使用過的前面 5 個批次的資料,按說就不應該再次被用來訓練了。

    • 問題1: 恢復訓練之後,應該怎麼去除已經處理的資料index?
  • 如果加入或者減少節點,如果告訴 Sampler,我們需要更改提取規則,最起碼,num_replicas 需要被更新,以後按照新的 num_replicas 進行提取,比如原來5個節點,num_replicas = 5,現在6個節點,num_replicas 應該為 6。
    • 問題2: 恢復訓練之後,何時呼叫 __iter__以進行新的訓練?
    • 問題3: 恢復訓練之後,何時修改 num_replicas?

我們看看 DistributedSampler 就會發現,其__iter__之中,沒有任何儲存狀態的相關資訊。即如果重新開始訓練,依然會從全體資料中提取,而非從剩餘資料中提取。也沒有發現對後面兩個問題的解決辦法。

因此,很難利用 DistributedSampler進行彈性訓練,所以 Horovod 就使用 ElasticSampler 來解決這個問題。

2.2 ElasticSampler

2.2.1 定義

從註釋中我們可以看到,ElasticSampler 自稱與 DistributedSampler 非常類似。我們隨後針對兩個類程式碼比較可以看到,功能基本一致。

但是有兩個新加入的變數值得注意,即:

    self.processed_indices = set()
self.remaining_indices = []

定義如下:

import math
import random
import torch.utils.data.distributed
from horovod.torch.mpi_ops import rank, size class ElasticSampler(torch.utils.data.Sampler):
"""Sampler that partitions dataset across ranks and repartitions after reset events. Works similar to `DistributedSampler`, but with an optional capability to record
which dataset indices have been processed each batch. When tracked by a `TorchState`
object, the sampler will automatically repartition the unprocessed indices among the
new set of workers. In order to use this object successfully it is recommended that the user: 1. Include this object in the `TorchState`.
2. Call `record_batch` or `record_indices` after processing a set of samples.
3. Call `set_epoch` at the end of each epoch to clear the processed indices. Args:
dataset: Dataset used for sampling (assumed to be of constant size).
shuffle: If `True` (default), shuffle the indices.
seed: Random seed used to shuffle the sampler when `shuffle=True`.
This number should be identical across all ranks (default: 0).
"""
def __init__(self, dataset, shuffle=True, seed=0):
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed self.epoch = 0
self.processed_indices = set() # 新加入的特色成員變數 self.num_replicas = 0
self.rank = 0
self.remaining_indices = [] # 新加入的特色成員變數
self.num_samples = 0
self.total_size = 0 self.reset()

2.2.2 彈性方案

具體彈性方案就圍繞之前提到的兩個變數來進行。

2.2.2.1 常規流程

我們回憶其註釋中提到的如何使用:

1. Include this object in the `TorchState`.
2. Call `record_batch` or `record_indices` after processing a set of samples.
3. Call `set_epoch` at the end of each epoch to clear the processed indices.

我們可以推匯出來其內在邏輯:

  • 進行本 epoch 訓練。

    • 當使用 __iter__ 獲取下一批次資料時候,self.indices = self.remaining_indices[:] 就會 只從未訓練的資料裡面提取
    • 每處理一個批次資料 之後,使用者使用 record_batch 或者 record_indices 來把已經訓練完的資料批次資訊儲存在 processed_indices。這樣就記錄了已經訓練完的資料
    • 如果產生了問題,或者有節點變更,則:
      • 會呼叫 reset 函式,reset 會把已經訓練完的資料 processed_indices 從總資料中移除,剩下的 self.remaining_indice就是沒有訓練的資料。
      • 恢復訓練, 只從未訓練的資料裡面提取
  • 當完成這個epoch 之後,會呼叫 set_epoch 來重置 processed_indices,也會呼叫 reset 方法進行清零。

具體功能程式碼是:

def set_epoch(self, epoch):
"""Sets the epoch for this sampler. When `shuffle=True`, this ensures all replicas use a different random ordering
for each epoch. Will clear and reset the `processed_indices` for the next epoch. It is important
that this is called at the end of the epoch (not the beginning) to ensure that
partially completed epochs do not reprocess samples. Args:
epoch: Epoch number.
"""
self.epoch = epoch
# 這裡也許有網友會有疑問,就是下面兩行程式碼應該交換一下次序。
# 但是實際上是沒有問題的,因為 reset 其實在異常處理時候的作用更大,在這裡其實就是個清零作用。
self.processed_indices = set()
self.reset() def record_batch(self, batch_idx, batch_size):
"""Record indices at batch `batch_idx` with length `batch_size` as processed."""
indices = set(self.get_indices(batch_idx, batch_size))
self.record_indices(indices) def record_indices(self, indices):
"""Record set `indices` as processed."""
self.processed_indices.update(indices) # 記錄已經訓練完的資料 def get_indices(self, batch_idx, batch_size):
"""Return list of indices at batch `batch_idx` with length `batch_size`."""
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(self.indices))
return self.indices[start_idx:end_idx] def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.processed_indices = state_dict['processed_indices'] # 從儲存的資料中提取
self.reset() def state_dict(self):
return dict( # 這裡是為了State.save 時候呼叫,就是模型儲存時候,需要儲存這兩個變數
epoch=self.epoch,
processed_indices=self.processed_indices
) def reset(self):
# size 程式碼位於horovod/torch/mpi_ops.py,是 size = _basics.size,可以認為就是 hvd.size()
self.num_replicas = size() # 重新配置有幾個worker
self.rank = rank() # Exclude any samples we have already processed this epoch
# 把已經訓練完的資料移除,得到的資料 remaining_indices 都是沒有經過訓練的
self.remaining_indices = [idx for idx in range(len(self.dataset))
if idx not in self.processed_indices] self.num_samples = int(math.ceil(len(self.remaining_indices) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas def __iter__(self):
self.indices = self.remaining_indices[:] # 從剩餘資料中提取
if self.shuffle:
# Shuffle indices across workers deterministically in place
seed = self.seed + self.epoch
random.Random(seed).shuffle(self.indices) # add extra samples to make it evenly divisible
self.indices += self.indices[:(self.total_size - len(self.indices))]
assert len(self.indices) == self.total_size # subsample
# 本worker如何遍歷?起始index是self.rank,終止index是總資料長度,按照num_replicas來遞增
self.indices = self.indices[self.rank:self.total_size:self.num_replicas]
assert len(self.indices) == self.num_samples # 後續就按照上面的遍歷邏輯來遍歷
return iter(self.indices) def __len__(self):
return self.num_samples
2.2.2.2 異常處理

在 horovod/torch/elastic/state.py 之中,當重新訓練時候,會呼叫到 ElasticSampler 的 load_state_dict 方法。

而 load_state_dict 之中,會呼叫 reset,這樣就把已經訓練完的資料移除,得到的資料都是沒有經過訓練的。

所以重新訓練時候,本epoch之內,不會用已經訓練的資料再次重複訓練。

我們後續會詳細分析這個流程

2.2.1 如何使用

ElasticSampler 的使用如下,程式碼位於:examples/elastic/pytorch/pytorch_imagenet_resnet50_elastic.py。

本節我們主要介紹如何使用,就是正常使用/處理流程,後續會介紹異常處理,這裡省略部分次要程式碼。

2.2.1.1 主體程式碼

主體程式碼主要注意就是使用ElasticSampler分別配置了兩個彈性取樣器。

if __name__ == '__main__':
allreduce_batch_size = args.batch_size * args.batches_per_allreduce # Elastic Horovod: use ElasticSampler to partition data among workers.
train_dataset = datasets.ImageFolder()
train_sampler = hvd.elastic.ElasticSampler(train_dataset) # 配置了彈性取樣
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=allreduce_batch_size,
sampler=train_sampler,
**kwargs) val_dataset = datasets.ImageFolder()
val_sampler = hvd.elastic.ElasticSampler(val_dataset) # 配置了彈性取樣
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.val_batch_size,
sampler=val_sampler,
**kwargs) # Set up standard ResNet-50 model.
model = models.resnet50() # Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(),
lr=(args.base_lr *
lr_scaler),
momentum=args.momentum, weight_decay=args.wd) # Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=args.batches_per_allreduce,
op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor) # Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers. state = hvd.elastic.TorchState(model=model,
optimizer=optimizer,
train_sampler=train_sampler,
val_sampler=val_sampler,
epoch=resume_from_epoch,
batch=0) full_train(state)
2.2.1.2 訓練程式碼

以下程式碼是具體訓練程式碼。

def train(state):

    model.train()
epoch = state.epoch batch_offset = state.batch
with tqdm(total=len(train_loader),
desc='Train Epoch #{}'.format(epoch + 1),
disable=not verbose) as t: # 迴圈獲取資料,會間接呼叫到 ElasticSampler 的 __iter__ 方法來獲取資料 index
for idx, (data, target) in enumerate(train_loader):
# Elastic Horovod: update the current batch index this epoch
# and commit / check for host updates. Do not check hosts when
# we commit as it would be redundant.
state.batch = batch_idx = batch_offset + idx
if args.batches_per_commit > 0 and \
state.batch % args.batches_per_commit == 0:
state.commit()
elif args.batches_per_host_check > 0 and \
state.batch % args.batches_per_host_check == 0:
state.check_host_updates() adjust_learning_rate(epoch, batch_idx) optimizer.zero_grad() # Split data into sub-batches of size batch_size
for i in range(0, len(data), args.batch_size):
data_batch = data[i:i + args.batch_size]
target_batch = target[i:i + args.batch_size]
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss)
# Average gradients among sub-batches
loss.div_(math.ceil(float(len(data)) / args.batch_size))
loss.backward() # Elastic Horovod: record which samples were processed this batch
# so we do not reprocess them if a reset event occurs
# 這裡會記錄已經完成的資料
state.train_sampler.record_batch(idx, allreduce_batch_size) # Gradient is applied across all ranks
optimizer.step() state.commit() def end_epoch(state):
state.epoch += 1
state.batch = 0
state.train_sampler.set_epoch(state.epoch) # 這裡會對剩餘資料資訊清零
state.commit() @hvd.elastic.run
def full_train(state):
while state.epoch < args.epochs:
train(state)
validate(state.epoch)
save_checkpoint(state.epoch)
end_epoch(state) # 這裡會對剩餘資料資訊清零

某一個epoch具體邏輯(正常處理)如下:

  1. 如果是最初執行,則呼叫reset進行初始化,其中會依據 dataset 長度構建一個 index list。用這個index list 減去 processed_indices ,就得到了本次epoch應該處理的資料 index,賦值給 remaining_indices,就是剩下來應該處理的資料index;
  2. __iter__ 函式中,呼叫 self.indices = self.remaining_indices[:] ,這樣 indices 就可以用來做迭代提取;
  3. 訓練函式中,呼叫 iter(indices) 進行迭代提取,然後呼叫 record_indices 把本次使用過的index 更新到 processed_indices 之中。processed_indices 就記錄了目前使用的所有index;
  4. epoch 結束之後,呼叫 set_epoch 進行重置,即給 processed_indices 清零,呼叫 reset 重置 remaining_indices;
              +---------------------------------------------------------------+
| ElasticSampler |
| |
+--------------------------------------------> + |
4 | set_epoch | | |
| | | |
| | 1 | reset |
| | | |
| | | |
| | v |
| | |
| | remaining_indices = dataset - processed_indices |
| | |
| | + |
| | | |
| | | |
| | 2 | __iter_ |
| | | |
| | | |
| | v |
| | indices = remaining_indices[:] |
| | + |
| | | |
| +---------------------------------------------------------------+
| |
| 3 |
| |
| v
| +--------------------------------------+------------------------------------+
| | train() train loop |
| | |
| | ----------------------------> iter(indices)+--------------------> |
| | ^ | |
| | | | |
| | step() backward() |
| | | +----------------------------------------+ | |
| | | |record_indices | | |
| | | | | | |
| | <-------------+ processed_indices.update(indices) +------+ v |
| | | | |
| | +----------------------------------------+ |
| | |
| +---------------------------------------+-----------------------------------+
| |
| |
+-----------------------------------------------+

0x03 儲存和定期檢查

3.1 定期儲存

Hovorod 建議使用者定週期性呼叫 state.commit() 來把狀態(state)備份到記憶體。

  • 定期備份非常有用。在某些worker發生意外錯誤時,定期備份可以避免因為狀態被損壞而在重新訓練時候無法恢復現場。比如,如果一個worker剛好在更新引數過程中突然出錯,此時部分梯度更新完畢,部分梯度可能只更新到一半,這個狀態是不可逆轉而又無法繼續。因此,當此狀態發生時,會丟擲一個 HorovodInternalError 異常,當 hvd.elastic.run 捕獲到這個異常後,會利用最新一次commit中恢復所有狀態
  • 因為commit狀態代價高昂(比如如引數量太大會導致耗時過長),所以需要在"每個batch的處理時間"與"如果出錯,訓練需要從多久前的狀態恢復"之間選取一個平衡點。比如,如果你每訓練10個batches就commit一次,你就把複製時間降低了10倍。但是當發生錯誤時,你需要回滾到10個batches前的狀態。
  • Elastic Horowod可以通過執行我們稱之為“優雅地移除worker”操作來避免這些回滾。如果driver程序發現主機已可用或標記為刪除,它將向所有workers推送一個通知。於是在下次呼叫state.commit()或更輕量級的state.check_host_updates()時,一個HostsUpdatedInterrupt異常將被丟擲。此異常的處理方式與“HorovodInternalError”類似,只是引數狀態不會還原到上次commit,而是從當前實時引數中恢復
  • 一般來說,如果你的硬體設施是可靠與穩定的,並且你的編排系統會在任務節點移除時提供足夠的告警,你就可低頻次呼叫 state.commit() 函式,同時只在每個batch結束時呼叫相對不耗時的 state.check_host_updates() 來檢查節點變更情況。

具體示例程式碼如下:

@hvd.elastic.run
def train(state):
for state.epoch in range(state.epoch, epochs):
for state.batch in range(state.batch, batches_per_epoch):
data, target = get_random_batch()
train_one_batch(data, target)
if state.batch % batches_per_commit == 0:
state.commit() # 定期儲存
state.batch = 0

3.2 異常處理

我們可以看到,HorovodInternalError 和 HostsUpdatedInterrupt 這兩個異常最大的區別:

  • HorovodInternalError 異常:當 hvd.elastic.run 捕獲到這個異常後,會利用最新一次commit中恢復所有狀態
  • HostsUpdatedInterrupt 異常:處理方式與“HorovodInternalError”類似,只是引數狀態不會還原到上次commit,而是從當前實時引數中恢復

之所以要強調這個,因為後面就要介紹如何做到不同恢復。

3.3 Commit

在使用者呼叫 State.commit 的時候,有兩個動作:一個是儲存狀態。一個是呼叫 check_host_updates 檢查更新。

class State(object):
"""State representation used for tracking in memory state across workers.""" def commit(self):
self.save()
self.check_host_updates()

這裡 save 就會呼叫到 State 的 save 操作,結合本文,就是下面要介紹的 TorchState 的 save 操作。

另外,check_host_updates 會丟擲HostsUpdatedInterrupt異常。HostsUpdatedInterrupt 異常裡面,是否需要 sync,從下面 check_host_updates 程式碼可以看出來,就是如果節點數目有變化了,就需要sync。HostUpdateResult.removed 數值為1,這裡其實可以改進,HostUpdateResult.removed 在目前這個情況之下,設定過細了。

class HostUpdateResult(IntFlag):
no_update = 0
removed = 1
added = 2
mixed = removed | added def check_host_updates(self):
"""Checks that a notification has been sent indicating that hosts can be added or will be removed. Raises a `HostsUpdatedInterrupt` if such a notification has been received.
"""
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update # In order to ensure all workers raise the exception at the same time, we need to sync
# the updated state across all the workers.
# TODO(travis): this should be a max allreduce to account for changes in rank 0
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update)) # At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
# 在這裡設定,其實含義就是:如果節點有變化,就設定為True,需要同步
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed) # 丟擲異常

0x04 State

我們接下來介紹異常處理邏輯,具體圍繞著 State 來介紹。對於State,我們先回憶一下其在恢復訓練時候的邏輯。

4.1 恢復訓練

重新訓練時候,會丟擲兩種異常:

  • 如果是 ring allreduce 相關,就轉為丟擲異常 HorovodInternalError(e)。
  • 如果當驅動程序通過節點發現指令碼發現一個節點被標記為新增或者移除時,會丟擲異常 HostsUpdatedInterrupt。

然後會進行如下處理:

def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False try:
while True:
if not skip_sync:
state.sync() # 進行同步 try:
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore() # 進行恢復訓練
skip_sync = False # 需要同步
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync # 記錄是否需要同步 reset()
state.on_reset() # 進行重啟
finally:
notification_manager.remove_listener(state)
return wrapper

邏輯如下:

+------------------------------------------------------------------------------+
| Worker |
| |
| +------------------------------------------------------------------------+ |
| | run_fn | |
| | +----------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +------------------+---------------+ | | |
| | | train | | | |
| | | | | | |
| | | optimizer.apply_gradients +---------+ | | |
| | | | | | | |
| | +-------+ state.commit() | | | |
| | | | | | | | |
| | | +----------------------------------+ | | | |
| | | | | | |
| | v v | | |
| | HostsUpdatedInterrupt HorovodInternalError | | |
| | + | | |
| | + | | | |
| | | | | | |
| | | v | | |
| | | state.restore() | | |
| | | + | | |
| | | | | | |
| | +------------------+ <------------------+ | | |
| | | | | | |
| | | | | | |
| | v v | | |
| | reset() | | |
| | | | |
| | state.on_reset() | | |
| | | | |
| | + | | |
| | | | | |
| | +-----------------------------------> | |
| | | |
| +------------------------------------------------------------------------+ |
| |
+------------------------------------------------------------------------------+

因為這裡涉及了大量的state操作,所以我們接下來要看看 TorchState:

4.2 TorchState

首先,我們要看看 TorchState 如何使用。當呼叫時候,使用如下方法來生成一個TorchState:

    state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset]) # 註冊使用者定義的方法 on_state_reset
train(state)

其次,我們看看 TorchState 的定義,這裡的 sync,restore,reset方法就在恢復訓練中被呼叫。

在初始化函式 __init__ 之中,會設定 handler,以我們的呼叫為例,就是 train_sampler,val_sampler這兩個對應的sampler會配置對應的handler,即SamplerStateHandler。

TorchState 繼承了 ObjectState,ObjectState 繼承了 State,所以前面提到的 commit 程式碼中的 self.save(),就會呼叫到TorchState.save,而這裡又會呼叫到 SamplerStateHandler.save

class TorchState(ObjectState):
"""State representation of a PyTorch training process. Multiple models and optimizers are supported by providing them as
kwargs. During initialization, `TorchState` will assign attributes
for every keyword argument, and handle its state synchronization. Args:
model: Optional PyTorch model.
optimizer: Optional PyTorch optimizer.
kwargs: Attributes sync, will be exposed as attributes of the object. If a handler exists
for the attribute type, it will be used to sync the object, otherwise it will be
handled an ordinary Python object.
"""
def __init__(self, model=None, optimizer=None, **kwargs):
kwargs.update(dict(model=model, optimizer=optimizer))
# 這裡會設定 handler,以我們的呼叫為例,就是train_sampler,val_sampler這兩個對應的sampler會配置對應的handler
self._handlers, kwargs = _get_handlers(kwargs)
for name, handler in self._handlers.items():
setattr(self, name, handler.value)
super(TorchState, self).__init__(bcast_object=broadcast_object,
get_rank=rank,
**kwargs) def save(self):
for handler in self._handlers.values():
handler.save() # 呼叫到save,針對我們,就是呼叫到了SamplerStateHandler的save
super(TorchState, self).save() def restore(self):
# 會進行恢復狀態
for handler in self._handlers.values():
handler.restore() # 這裡會呼叫到sampler的restore方法。
super(TorchState, self).restore() def sync(self):
# 會進行同步狀態
for handler in self._handlers.values():
handler.sync() # 這裡會呼叫到sampler的sync方法。
super(TorchState, self).sync() def __setattr__(self, name, value):
if hasattr(self, name) and name in self._handlers:
self._handlers[name].set_value(value)
super().__setattr__(name, value)

基類程式碼中有:

class State(object):

    def on_reset(self):
self._host_messages = queue.Queue()
self.reset() # 呼叫到reset
for callback in self._reset_callbacks:
callback()

4.3 設定 handler

上節中,我們可以看到,無論是reset,還是restore,都會呼叫到 _handlers 來進行處理,所以我們需要進一步分析。

首先就是如何設定handler。具體參見如下程式碼,主要是通過一個全域性配置 _handler_registry 來指定哪個 handler 處理哪種型別例項,比如這裡有 (ElasticSampler, SamplerStateHandler),就代表著 SamplerStateHandler 是用來處理 ElasticSampler的 handler。

_handler_registry = [
(torch.nn.Module, ModelStateHandler),
(torch.optim.Optimizer, OptimizerStateHandler),
(ElasticSampler, SamplerStateHandler), # SamplerStateHandler 是用來處理 ElasticSampler的
] def get_handler_registry():
return _handler_registry def set_handler_registry(registry):
global _handler_registry
_handler_registry = registry def _get_handler(v):
# 依據我們的樣例程式碼,v是 train_sampler,而 train_sampler,val_sampler就是 ElasticSampler 的例項,所以得到 handler_type是 ElasticSampler,則會構建一個 SamplerStateHandler 並且返回
for handler_type, handler_cls in _handler_registry:
if isinstance(v, handler_type):
return handler_cls(v) # 呼叫 SamplerStateHandler(train_sampler) 生成例項
return None def _get_handlers(kwargs):
handlers = {}
remainder = {}
# 這裡k,v就是 train_sampler=train_sampler,所以 k 是 "train_sampler", v是例項 train_sampler
for k, v in kwargs.items():
handler = _get_handler(v)
if handler:
handlers[k] = handler
else:
remainder[k] = v
return handlers, remainder

4.4 SamplerStateHandler

既然知道了 ElasticSampler 由 SamplerStaeHandler 處理,就來分析一下 SamplerStateHandler。

初始化之後,self.value 就是 sampler,針對我們之前的分析,就是ElasticSampler

SamplerStateHandler 具體程式碼是,這裡需要注意的是:初始化時候,會把ElasticSampler的狀態儲存起來,以後如果出錯,會用此來恢復。

同時,save 也會被呼叫,用來恢復,我們馬上就會分析。

class SamplerStateHandler(StateHandler):
def __init__(self, sampler):
super().__init__(sampler)
# 這裡會儲存 ElasticSampler 的屬性和資料
self._saved_sampler_state = copy.deepcopy(self.value.state_dict()) def save(self):
# 儲存 ElasticSampler 的屬性和資料
self._saved_sampler_state = copy.deepcopy(self.value.state_dict()) def restore(self):
# load_state_dict 會用__init__ 之中儲存的原始資料來恢復,最終會呼叫到 ElasticSampler.reset 方法
self.value.load_state_dict(self._saved_sampler_state) def sync(self):
# 1)Get the set of processed indices from all workers
world_processed_indices = _union(allgather_object(self.value.processed_indices)) # 2) Replace local processed indices with global indices
state_dict = self.value.state_dict() # 這裡會呼叫到 ElasticSampler 的 state_dict 方法
state_dict['processed_indices'] = world_processed_indices # 3) Broadcast and load the state to make sure we're all in sync
# 注意,這裡的 load_state_dict 最終也會呼叫一次 reset
self.value.load_state_dict(broadcast_object(state_dict))

SamplerStateHandler 的 基類是:

class StateHandler(object):
def __init__(self, value):
self.value = value def save(self):
raise NotImplementedError() def restore(self):
raise NotImplementedError() def sync(self):
raise NotImplementedError() def set_value(self, value):
self.value = value
self.save()

4.5 儲存

我們拓展一下save相關操作序列。

TorchState 繼承了 ObjectState,ObjectState 繼承了 State,所以:

  1. 前面提到的 commit 程式碼中的 self.save(),就會呼叫到TorchState.save。
  2. 而TorchState.save又會呼叫到 SamplerStateHandler.save。
  3. SamplerStateHandler.save 會儲存 ElasticSampler 的屬性和資料,就是儲存了 ElasticSampler 的 epoch 和 processed_indices。

這樣,在定期 commit 的時候,就定期儲存了模型的狀態和 ElasticSampler 的狀態,這些會在恢復訓練中用到。具體下圖所示:

               +---------------------------+
| TorchState |
| |
| commit |
| + |
| | |
| | 1 |
| | |
| v |
| save |
| | |
| | |
+---------------------------+
|
| 2
|
|
+-----------------------------------------------------------------+
|SamplerStateHandler | |
| | |
| | |
| | |
| | |
| def save(self): v |
| |
| _saved_sampler_state = copy.deepcopy( value.state_dict() ) |
| + |
| | |
+-----------------------------------------------------------------+
|
|
| 3
|
|
+------------------------------------------+
| ElasticSampler | |
| | |
| | |
| | |
| def state_dict(self): | |
| return dict( v |
| self.epoch, |
| self.processed_indices |
| ) |
| |
+------------------------------------------+

只看靜態定義,還是很難理解,需要分析動態流程。因為有兩種異常,所以我們分開剖析

回憶一下兩個異常最大的區別:

  • HorovodInternalError 異常:當 hvd.elastic.run 捕獲到這個異常後,會利用最新一次commit中恢復所有狀態
  • HostsUpdatedInterrupt 異常:處理方式與“HorovodInternalError”類似,只是引數狀態不會還原到上次commit,而是從當前實時引數中恢復

4.6 HostsUpdatedInterrupt

如果當驅動程序通過節點發現指令碼發現一個節點被標記為新增或者移除時,會丟擲異常 HostsUpdatedInterrupt。此時不是關鍵異常,因此可以繼續訓練本epoch,只是從後續訓練資料中,移除本epoch已經處理的資料。因此可以做到 引數狀態不會還原到上次commit,而是從當前實時引數中恢復

下面程式碼之中,我們只保留 HostsUpdatedInterrupt 相關程式碼。

def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False try:
while True:
if not skip_sync:
state.sync() # 3) 進行同步 try:
return func(state, *args, **kwargs) # 這裡會出錯,而且重新訓練也是來到這裡
except HostsUpdatedInterrupt as e:
# 1) 進行異常處理
skip_sync = e.skip_sync # 2.1) 記錄是否需要同步 reset() # 2)這裡會呼叫_basics.init 重新初始化 horovod,間接設定了ElasticSampler之中的 num_replicas
state.on_reset() # 進行重啟
finally:
notification_manager.remove_listener(state)
return wrapper

發生異常之後,

  • 1)HostsUpdatedInterrupt 表示本 epoch 需要繼續訓練,所以進行異常處理,其中只是會:

    • 1.1) 記錄本異常處理是否需要同步 :skip_sync = e.skip_sync。
  • 2)這個步驟主要是重啟 hvd,對worker數目進行更改。具體是呼叫 State 自身的 reset() 方法(程式碼位於horovod/torch/elastic/__init__.py),其中會:
    • 2.1) 呼叫 shutdown() 來結束本次任務。
    • 2.2) 呼叫 init(),從而呼叫_basics.init,最終重新建立 MPI 相關 context,所以 hvd.size() 就根據最新的worker數目進行了更改。後續 ElasticSampler.__iter__ 之中會相應修改num_replicas。
  • 3)這個步驟是把已經訓練完的資料移除,得到的資料都是沒有經過訓練的。如果需要同步,則會呼叫 state.sync() ,其會呼叫 SamplerStateHandler.sync 方法,其內部會:
    • 3.1) SamplerStateHandler會利用集合通訊從所有worker中收集processed_indices,賦予給 world_processed_indices,這就是所有workers 已經處理過的資料 index
    • 3.2) 呼叫 ElasticSampler.state_dict方法,得到本地 ElasticSampler.epoch 和 ElasticSampler.processed_indices 的引用。然後將 world_processed_indices 賦值給 state_dict['processed_indices'],這樣,本地 ElasticSampler.processed_indices 就是所有workers 已經處理過的資料 index
    • 3.3) self.value.load_state_dict(broadcast_object(state_dict)) 有兩步操作:
      • 廣播,這樣在同步之後,所有worker都有同樣的 state_dict['processed_indices'] 資料了。
      • load_state_dict 會再呼叫一次 ElasticSampler.reset此次 reset 會更改 num_replicas,也會從總資料中去除processed_indices,得到新的 remaining_indices, 從而 後續 __iter__ 之中,就會相應對提取index 的策略進行相應更改
  • 4)所以這樣就把已經訓練完的資料移除,所以得到的 remaining_indices 資料都是沒有經過訓練的。所以重新訓練時候,本epoch之內,不會用已經訓練的資料再次重複訓練,而是從當前實時引數中恢復。
    • 重新訓練會呼叫 return func(state, *args, **kwargs) 進行訓練,這裡會處理 ElasticSampler.__iter__
    • 當使用 __iter__ 獲取下一批次資料時候,self.indices = self.remaining_indices[:] 就會 只從未訓練的資料裡面提取

具體邏輯如下:

+-----------------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +-----------------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +-----------------------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v 3) | | |
| | state.sync() +------------------------------------------+----------------------+ | | |
| | | | | | |
| | + | | | | |
| | | | | | | |
| | | | | | | |
| | v | | | | |
| | +------------------+---------------+ 3.1) | 3.2) | | | |
| | | train | | | | | |
| | | | | | | | |
| | | optimizer.apply_gradients +---------+ | | | | |
| | | + | v | | | |
| | +-------+ state.commit() | | | | |
| | | | + | ElasticSampler.load_state_dict | | | |
| | | +----------------------------------+ | + | | | |
| | | | | | | | |
| | v v | | | | |
| | HostsUpdatedInterrupt HorovodInternalError v | | | |
| | + ElasticSampler.reset | | | |
| | + | + | | | |
| | | | | | | | |
| | | 1) v | | | | |
| | | state.restore() v | | | |
| | | + +-----------+-----------------+ | | | |
| | | | | ElasticSampler | | | | |
| | +------------------+ <------------------+ | | | | | |
| | | | | remaining_indices | | | | |
| | | | | | | | | |
| | v v | num_samples | | | | |
| | reset() | | | | | |
| | 2) | total_size | | | | |
| | state.on_reset() | | | | | |
| | | epoch | | | | |
| | + | | | | | |
| | | | processed_indices | | | | |
| | | | | | | | |
| | | | state_dict <-------------+ | | |
| | | | | | | |
| | | +-----------------------------+ | | |
| | | | | |
| | +------------------------------------------------------------------------------^ | |
| | | |
| +-----------------------------------------------------------------------------------------------------------------+ |
| |
+-----------------------------------------------------------------------------------------------------------------------+

手機如下:

4.7 HorovodInternalError

如果是 ring allreduce 相關,就轉為丟擲異常 HorovodInternalError(e)。HorovodInternalError 是關鍵異常,此時本 epoch 現有狀態其實意義不大,應該利用最新一次commit中恢復所有狀態

下面程式碼之中,我們只保留 HorovodInternalError 相關程式碼。

def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False try:
while True:
if not skip_sync:
state.sync() # 3) 進行同步
try:
return func(state, *args, **kwargs) # 這裡會出錯,而且重新訓練也是來到這裡
except HorovodInternalError:
# 1) 進行異常處理
state.restore() #1.1) 進行恢復訓練,這裡是和 HostsUpdatedInterrupt 的不同之處
skip_sync = False # 1.2) 記錄需要同步 reset() # 2)這裡會呼叫_basics.init 重新初始化 horovod,間接設定了ElasticSampler之中的 num_replicas
state.on_reset() # 進行重啟
finally:
notification_manager.remove_listener(state)
return wrapper

HorovodInternalError 和 HostsUpdatedInterrupt 的程式碼路徑幾乎一樣,只是多了一步 state.restore() 。

這裡為啥也要走檢視節點變化這個程式碼路徑呢?因為Horovod是定期檢查節點變化,所以可能產生HorovodInternalError時候,也有節點變化了,只是還沒有發現而已,所以可以一併處理了。

具體邏輯為:

  • 1)HorovodInternalError 表示本 epoch 需要恢復訓練,所以先進行異常處理:

    • 1.1)state.restore() 會呼叫 SamplerStateHandler.restore(這裡是與HostsUpdatedInterrupt處理差異之處)。

      • 進而呼叫 ElasticSampler.load_state_dict方法,會用在SamplerStateHandler.__init__ 或者SamplerStateHandler.save 之中原始儲存的資料來恢復 ElasticSampler。儲存的資料就是 processed_indices 和 epoch。
      • ElasticSampler.load_state_dict方法 進而會呼叫 ElasticSampler.reset方法,使用 processed_indices 把已經訓練完的資料移除,最新得到的 remaining_indices 資料都是沒有經過訓練的(針對上次儲存的 processed_indices 來說)。
    • 1.2) 記錄本異常處理需要同步 : skip_sync = False。
  • 2)這個步驟主要是重啟 hvd。呼叫 State 自身的 reset() 方法(程式碼位於horovod/torch/elastic/__init__.py),其中會:
    • 2.1) 呼叫 shutdown() 來結束本次任務。
    • 2.2) 呼叫 init(),從而呼叫_basics.init,最終重新建立 MPI 相關 context。
  • 3)這個步驟是把已經訓練完的資料移除,得到的資料都是沒有經過訓練的。因為這裡需要同步,所以會呼叫 state.sync() ,其會呼叫 SamplerStateHandler.sync 方法,其內部會:
    • 3.1) SamplerStateHandler會利用集合通訊從所有worker中收集processed_indices,賦予給 world_processed_indices,這就是所有workers 已經處理過的資料 index。需要注意的是:因為是使用在__init__ 或者 save之中原始儲存的資料來恢復,所以其實這一步是恢復到上次commit狀態
    • 3.2) 呼叫 ElasticSampler.state_dict方法,得到本地 ElasticSampler.epoch 和 ElasticSampler.processed_indices 的引用。然後將 world_processed_indices 賦值給 state_dict['processed_indices'],這樣,本地 ElasticSampler.processed_indices 就是所有workers 已經處理過的資料 index
    • 3.3) 這裡 self.value.load_state_dict(broadcast_object(state_dict)) 有兩步操作:
      • 廣播,這樣在同步之後,所有worker都有同樣的 state_dict['processed_indices'] 資料了。
      • load_state_dict 會再呼叫一次 ElasticSampler.reset此次 reset 會更改 num_replicas,也會從總資料中去除processed_indices,得到新的 remaining_indices, 從而 後續 __iter__ 之中,就會相應對提取index 的策略進行相應更改
  • 4)這樣就是恢復到epoch 上次 commit 的狀態進行訓練
    • 重新訓練會呼叫 return func(state, *args, **kwargs) 進行訓練,這裡會處理 ElasticSampler.__iter__
    • 當使用 __iter__ 獲取下一批次資料時候,self.indices = self.remaining_indices[:] 就會 只從未訓練的資料裡面提取

具體邏輯如下圖:

+--------------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +--------------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +-----------------------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v 3 | | |
| | state.sync() +-----------------------------------------------------------------+ | | |
| | | | | |
| | + +--------------+ | | | |
| | | | | | | | |
| | | | | | | | |
| | v | v | | | |
| | +------------------+---------------+ | | | | |
| | | train | | SamplerStateHandler.restore | | | |
| | | | | + | | | |
| | | optimizer.apply_gradients +---------+ | | | | | |
| | | + | | | | | | |
| | +-------+ state.commit() | | v | | | |
| | | | + | | ElasticSampler.load_state_dict | | | |
| | | +----------------------------------+ | | + | | | |
| | | | | | | | | |
| | v v | | | | | |
| | HostsUpdatedInterrupt HorovodInternalError | v | | | |
| | + | ElasticSampler.reset | | | |
| | + | | + | | | |
| | | | | | | | | |
| | | v 1 | | | | | |
| | | state.restore()+-----+ v | | | |
| | | + +-----------+-----------------+ | | | |
| | | | | ElasticSampler | | | | |
| | +------------------+ <------------------+ | | | | | |
| | | | | remaining_indices | | | | |
| | | | | | | | | |
| | v v | num_samples | | | | |
| | reset() 2 | | | | | |
| | | total_size | | | | |
| | state.on_reset() | | | | | |
| | | epoch | | | | |
| | + | | | | | |
| | | | processed_indices | | | | |
| | | | | | | | |
| | | | state_dict <-------------+ | | |
| | | | | | | |
| | | +-----------------------------+ | | |
| | | | | |
| | +------------------------------------------------------------------------------^ | |
| | | |
| +--------------------------------------------------------------------------------------------------------------+ |
| |
+--------------------------------------------------------------------------------------------------------------------+

手機如下:

4.8 ElasticSampler.__iter__

到目前為止,我們還有一個問題沒有仔細分析,就是何時呼叫 ElasticSampler.__iter__

我們仔細梳理一下:

以下是彈性訓練總體邏輯:

def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False try:
while True:
if not skip_sync:
state.sync() try:
# 如果出錯恢復,這裡會繼續呼叫 func 進行訓練
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore()
skip_sync = False
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync reset()
state.on_reset()
finally:
notification_manager.remove_listener(state)
return wrapper

彈性邏輯使用註解來封裝了full_train,所以 func 就是 full_train。

@hvd.elastic.run
def full_train(state):
while state.epoch < args.epochs:
train(state)
validate(state.epoch)
save_checkpoint(state.epoch)
end_epoch(state)

我們看看 train 的主要程式碼:

def train(state):
model.train()
epoch = state.epoch with tqdm(...) as t:
# 這裡 enumerate 之中會呼叫到 ElasticSampler.__iter__
for idx, (data, target) in enumerate(train_loader): # Split data into sub-batches of size batch_size
for i in range(0, len(data), args.batch_size):
data_batch = data[i:i + args.batch_size]
target_batch = target[i:i + args.batch_size]
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss)
# Average gradients among sub-batches
loss.div_(math.ceil(float(len(data)) / args.batch_size))
loss.backward() # Elastic Horovod: record which samples were processed this batch
# so we do not reprocess them if a reset event occurs
state.train_sampler.record_batch(idx, allreduce_batch_size) # Gradient is applied across all ranks
optimizer.step() state.commit()

所以我們可以理出來總體邏輯:

當出錯恢復時候,train 會再次被呼叫,呼叫時候就會使用 enumerate(train_loader)呼叫到 ElasticSampler.__iter__

num_replicas 在之前 reset 時候已經被設定,所以此時就是根據新的 world size 和 remaining_indices 重新確定提取資料的策略。

def __iter__(self):
self.indices = self.remaining_indices[:] # 從剩餘資料中提取
if self.shuffle:
# Shuffle indices across workers deterministically in place
seed = self.seed + self.epoch
random.Random(seed).shuffle(self.indices) # add extra samples to make it evenly divisible
self.indices += self.indices[:(self.total_size - len(self.indices))]
assert len(self.indices) == self.total_size # subsample
# 本worker如何遍歷?起始index是self.rank,終止index是總資料長度,按照 num_replicas 來遞增
self.indices = self.indices[self.rank:self.total_size:self.num_replicas]
assert len(self.indices) == self.num_samples # 後續就按照上面的遍歷邏輯來遍歷
return iter(self.indices)

具體邏輯如下,其中

1)在 reset 之中設定了num_replicas。

2)在 ElasticSampler.__iter__ 之中根據新的 world size 和 remaining_indices 重新確定提取資料的策略。

+----------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +----------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +----------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +--------------------------------+ +------------------+---------------+ | | |
| | | ElasticSampler | | train | | | |
| | | +---------------------------+ | | optimizer.apply_gradients +---------+ | | |
| | | | __iter__ | | 2) | | | | | |
| | | | | | <------------+ enumerate(train_loader) | | | | |
| | | | | | | | | | | |
| | | | remaining_indices | | +-------+ state.commit() | | | | |
| | | | | | | | | | | | |
| | | | | | | +----------------------------------+ | | | |
| | | | num_replicas | | v v | | |
| | | | | | HostsUpdatedInterrupt HorovodInternalError | | |
| | | | ^ | | + | | |
| | | | | | | + | | | |
| | | +---------------------------+ | | | | | |
| | +--------------------------------+ | v | | |
| | | | state.restore() | | |
| | | | + | | |
| | | | | | | |
| | | +------------------+ <------------------+ | | |
| | | | | | | |
| | | | | | | |
| | | 1) v v | | |
| | +----------------------------------------+ reset() | | |
| | | | |
| | state.on_reset() | | |
| | | | |
| | + | | |
| | | | | |
| | +-----------------------------------> | |
| | | |
| +----------------------------------------------------------------------------------------------------------+ |
| |
+----------------------------------------------------------------------------------------------------------------+

手機如下:

至此,彈性訓練如何恢復就分析完畢,以後可能結合 Pytorch 分散式 optimizer 來繼續分析。

0xFF 參考

PyTorch 中文手冊(2)-自動求導

pytorch中優化器optimizer.param_groups

PyTorch學習筆記6--案例2:PyTorch神經網路(MNIST CNN)

https://github.com/chenyuntc/pytorch-book