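[Source Code Analysis] PyTorch Distributed (2) --- Data Loading: DataLoader

0x00 Abstract

To better introduce data loading in the parameter server Paracel, we are temporarily inserting two articles on PyTorch data loading, approached mainly from the distributed angle. This article is only an appetizer; a dedicated series analyzing PyTorch distributed will follow.

Other articles in the parameter server series: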

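[Source Code Analysis] Machine Learning Parameter Server ps-lite (1) ----- PostOffice

[Source Code Analysis] Machine Learning Parameter Server ps-lite (2) ----- Communication Module Van

[Source Code Analysis] Machine Learning Parameter Server ps-lite (3) ----- Agent Customer

[Source Code Analysis] Machine Learning Parameter Server ps-lite (4) ----- Application Node Implementation

[Source Code Analysis] Machine Learning Parameter Server Paracel (1) ----- Overall Architecture

[Source Code Analysis] Machine Learning Parameter Server Paracel (2) ----- SSP Control Protocol Implementation

[Source Code Analysis] PyTorch Distributed (1) --- Data Loading: DistributedSampler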

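0x01 Recap

On data loading: last time we covered DistributedSampler; this article goes on to analyze DataLoader.

For clarity, we first repeat the pipeline diagram from the previous article; this article refines it.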

```text
                    +------------+
+--------+          |            |
|        |          | Process 1  |
| Data 1 +--------> |            +------+
|        |          | Load Data  |      |
+--------+          |            |      |
                    +------------+      |
                                        |
                                        |
                                        |
                    +------------+      |        +-----------------------------------+
+--------+          |            |      |        |                                   |
|        |          | Process 2  |      +------> | Pin-memory process                |
| Data 2 +--------> |            |               |                                   |
|        |          | Load Data  +-------------> |                                   |
+--------+          |            |               | Transfer to Pinned Memory         |
                    +------------+       +-----> |                                   |
                                         |       |                                   |
                                         |       +-----------------------------------+
                                         |
+--------+          +------------+       |
|        |          |            |       |
| Data 3 +--------> | Process 3  +-------+
|        |          |            |
+--------+          | Load Data  |
                    |            |
                    +------------+
```

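Next, the overall data-loading logic, shown in the figure below. In short:

  1. The Dataset gives its size (its length) to DistributedSampler.
  2. The Sampler generates data indices according to some strategy and sends them to DataLoader.
  3. DataLoader loads data from the Dataset according to the indices (its internal DataLoaderIter object coordinates single-process or multi-process loading from the Dataset).
  4. DataLoader feeds the data to the model for training.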
```text
+------------------------+                     +-----------+
|DistributedSampler      |                     |DataLoader |
|                        |     2 indices       |           |
|     Some strategy      +------------------>  |           |
|                        |                     |           |
+------------+-----------+                     |           |
             ^                                 |           |   4 data   +-------+
             |                                 |           +----------> | train |
           1 | length                          |           |            +-------+
             |                                 |           |
+------------+-----------+                     |           |
|DataSet                 |                     |           |
|   +---------+          |       3 Load        |           |
|   |  Data   +-----------------------------> |           |
|   +---------+          |                     |           |
|                        |                     |           |
+------------------------+                     +-----------+
```
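The four steps can be mimicked in plain Python. The sketch below is not PyTorch code: ToyDataset, toy_distributed_indices and toy_loader are hypothetical names, and the split (pad the index list to a multiple of num_replicas, then give each rank every num_replicas-th index starting at its rank) assumes DistributedSampler's non-shuffled, drop_last=False behaviour.

```python
import math
from typing import Iterator, List


class ToyDataset:
    """Map-style dataset: step 1 exposes its length, step 3 serves items by index."""

    def __init__(self, items: List[str]) -> None:
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> str:
        return self.items[idx]


def toy_distributed_indices(length: int, num_replicas: int, rank: int) -> List[int]:
    """Step 2: pad to a multiple of num_replicas, then take a strided slice per rank."""
    num_samples = math.ceil(length / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(length))
    indices += indices[: total_size - length]  # pad by repeating the first indices
    return indices[rank:total_size:num_replicas]


def toy_loader(dataset: ToyDataset, indices: List[int], batch_size: int) -> Iterator[List[str]]:
    """Steps 3 and 4: load items by index and hand them out batch by batch."""
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


if __name__ == "__main__":
    ds = ToyDataset(["a", "b", "c", "d", "e"])
    for rank in range(2):
        idx = toy_distributed_indices(len(ds), num_replicas=2, rank=rank)
        print(rank, list(toy_loader(ds, idx, batch_size=2)))
```

With five items and two replicas, rank 1 receives index 0 a second time: the padding is what gives every rank the same number of samples.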

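Now let's get into DataLoader proper.

0x02 DataLoader

DataLoader combines a Dataset and a Sampler and provides an iterator over the dataset.

One way to think about it:

The Dataset is the raw data; the Sampler provides the strategy for splitting the data (in other words, the dimension along which it is partitioned); and DataLoader is the laborer that does the actual work according to that strategy. Single-process loading is one person doing the job; multi-process loading brings in several people to do it together.

2.1 Initialization

The main initialization parameters are: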

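  • dataset (Dataset): the dataset to load from.
  • batch_size (int, optional): how many samples to load per batch.
  • shuffle (bool, optional): if True, the data is reshuffled at every epoch.
  • sampler (Sampler or Iterable, optional): defines the strategy for drawing samples from the dataset. Can be any Iterable that implements __len__.
  • batch_sampler (Sampler or Iterable, optional): like sampler, but returns a batch of indices at a time.
  • num_workers (int, optional): number of subprocesses used for data loading. 0 means the data is loaded in the main process.
  • collate_fn (callable, optional): merges a list of samples into a mini-batch of tensor(s). Used when doing batched loading from a map-style dataset.
  • pin_memory (bool, optional): if True, tensors are copied into CUDA pinned (page-locked) memory before being returned.
  • drop_last (bool, optional): when the dataset size is not divisible by the batch size, True drops the last incomplete batch; with False, the last batch is simply smaller.
  • timeout (numeric, optional): if positive, the timeout for collecting a batch from the workers.
  • worker_init_fn (callable, optional): if not None, it is called in each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading.
  • generator (torch.Generator, optional): if not None, it is used by RandomSampler to generate random indices and by the multiprocessing code to generate base_seed for the workers.
  • prefetch_factor (int, optional, keyword-only arg): number of samples each worker loads in advance (in the code it counts tasks, i.e. batches when batching is enabled).
  • persistent_workers (bool, optional): if True, the data loader does not shut down the worker processes after the dataset has been consumed once. This keeps the workers' Dataset instances alive.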
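To make collate_fn and drop_last concrete, here is a tensor-free sketch. toy_collate and num_batches are hypothetical helpers, not PyTorch API: toy_collate only imitates the "transpose the batch field by field" part of what default_collate does (default_collate would additionally stack numbers and tensors into tensors), and num_batches reproduces the batch count implied by the drop_last description above.

```python
import math
from typing import Any, Sequence


def toy_collate(batch: Sequence[Any]) -> Any:
    """Hypothetical collate_fn: merge a list of samples into one batch, field by field."""
    first = batch[0]
    if isinstance(first, dict):
        # list of dicts -> dict of per-key batches
        return {key: toy_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, (tuple, list)):
        # list of (x, y) -> (batch of x, batch of y)
        return tuple(toy_collate(field) for field in zip(*batch))
    # leaves stay a plain list here; default_collate would build a tensor instead
    return list(batch)


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches per epoch implied by batch_size and drop_last."""
    if drop_last:
        return num_samples // batch_size  # the incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)  # the last batch may be smaller
```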

The initialization code follows. It is mostly configuration; for clarity, the exception handling has been removed:

```python
class DataLoader(Generic[T_co]):
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # exception handling omitted
        else:
            self._dataset_kind = _DatasetKind.Map

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            # exception handling omitted
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None
        self._iterator = None

        self.check_worker_number_rationality()
```
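The branching above boils down to a small decision table. The helper below (choose_defaults and LoaderChoice are illustrative names, not PyTorch API) only reports, by name, which default sampler, batch sampler and collate function this __init__ would pick, assuming the user did not pass a sampler of their own and that, as in the PyTorch source, _auto_collation is simply `batch_sampler is not None`.

```python
from typing import NamedTuple, Optional


class LoaderChoice(NamedTuple):
    sampler: str
    batch_sampler: Optional[str]
    collate_fn: str


def choose_defaults(iterable_dataset: bool, shuffle: bool,
                    batch_size: Optional[int], has_batch_sampler: bool) -> LoaderChoice:
    """Report (by name) the defaults DataLoader.__init__ would pick when no sampler is given."""
    if has_batch_sampler:
        batch_size = None  # a custom batch_sampler takes over batching
    if iterable_dataset:
        sampler = "_InfiniteConstantSampler"
    elif shuffle:
        sampler = "RandomSampler"
    else:
        sampler = "SequentialSampler"
    batch_sampler: Optional[str]
    if has_batch_sampler:
        batch_sampler = "custom"
    elif batch_size is not None:
        batch_sampler = "BatchSampler"
    else:
        batch_sampler = None
    # auto-collation is on exactly when a batch_sampler exists
    collate_fn = "default_collate" if batch_sampler is not None else "default_convert"
    return LoaderChoice(sampler, batch_sampler, collate_fn)
```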

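2.2 Key Functions

One of the key members here is the _index_sampler property, through which the iterator obtains the sampler it will call; we will come to it shortly.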

```python
    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
```
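What _index_sampler yields can be sketched without PyTorch. toy_batch_sampler and toy_index_sampler are hypothetical: the first groups indices the way BatchSampler is documented to (full batches, plus the last partial batch unless drop_last), the second chooses between batches and single indices the way _index_sampler does.

```python
from typing import Iterable, Iterator, List, Optional, Union


def toy_batch_sampler(sampler: Iterable[int], batch_size: int,
                      drop_last: bool) -> Iterator[List[int]]:
    """Group indices into lists of batch_size, like BatchSampler."""
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch  # last, incomplete batch


def toy_index_sampler(sampler: Iterable[int], batch_size: Optional[int],
                      drop_last: bool = False) -> Iterable[Union[int, List[int]]]:
    """Batches of indices when auto-collating, otherwise the plain sampler."""
    if batch_size is not None:  # auto-collation: a batch sampler exists
        return toy_batch_sampler(sampler, batch_size, drop_last)
    return sampler
```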

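2.3 Single-Process Loading

In single-process mode, DataLoader loads data in the same process that runs the computation, so loading can block the computation.

The for statement calls enumerate, which returns an iterator used to traverse the dataset. Inside enumerate, DataLoader.__iter__ is called to create a DataLoader iterator, and that iterator's __next__(self) method is then called on every step to fetch the next item, so the dataset is traversed one item at a time.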

```python
cuda0 = torch.device('cuda:0')  # CUDA GPU 0
for i, x in enumerate(train_loader):
    x = x.to(cuda0)
```

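2.3.1 Creating the Iterator

With multi-process loading and persistent_workers=True, the iterator is created only once during the DataLoader's lifetime, so the workers can be reused across epochs.

In all other cases, including single-process loading, a new iterator is created on every call, which avoids having to reset state.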

```python
    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0:  # persistent workers with multi-process loading
            if self._iterator is None:  # create a new one only if none exists yet
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:  # single process, or workers are not persistent
            return self._get_iterator()  # always create a new one
```

Which iterator is created depends on whether multiple processes are used:

```python
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
```
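The caching rule in __iter__ is easy to check with a stand-in. ToyLoader and ToyIter are hypothetical classes that keep only the persistent_workers/num_workers decision and the reset-instead-of-recreate behaviour; there are no real workers behind them.

```python
from typing import Iterator, Optional


class ToyLoader:
    """Hypothetical stand-in that mimics the __iter__ caching policy shown above."""

    def __init__(self, data, num_workers: int = 0, persistent_workers: bool = False) -> None:
        self.data = data
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self._iterator: Optional["ToyIter"] = None

    def __iter__(self) -> "ToyIter":
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = ToyIter(self)  # created once, reused every epoch
            else:
                self._iterator.reset()  # same object, just rewound
            return self._iterator
        return ToyIter(self)  # fresh iterator on every call


class ToyIter:
    def __init__(self, loader: ToyLoader) -> None:
        self.loader = loader
        self.reset()

    def reset(self) -> None:
        self._it: Iterator = iter(self.loader.data)

    def __iter__(self) -> "ToyIter":
        return self

    def __next__(self):
        return next(self._it)
```

With persistent workers the same iterator object comes back every epoch and is merely rewound, which is what lets real worker processes survive between epochs.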

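2.3.2 Iterator Base Class

_BaseDataLoaderIter is the base class of the iterators; let's look at its key functions.

The key member variables are:

  • _index_sampler: set from the loader's _index_sampler, so the iterator knows which sampling strategy to use.
  • _sampler_iter: the iterator over that sampler.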

```python
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # initialize parameters
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler  # the sampling strategy
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)  # iterator over the sampler
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()  # fetch the data
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                # construction of the warning message omitted
                warnings.warn(warn_msg)
            return data
```

2.3.3 Single-Process Iterator

_SingleProcessDataLoaderIter inherits from _BaseDataLoaderIter. As you can see, it adds a _dataset_fetcher, constructed with _collate_fn and the other parameters.

Recall that __next__ calls self._next_data() to get data. Here, _next_data:

  • calls self._next_index(), which uses _sampler_iter (the sampler's iterator) to get the indices;
  • calls self._dataset_fetcher.fetch(index) to fetch the data for those indices.

```python
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        # the object that fetches samples
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        # fetch the samples
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    def _next_index(self):  # get the indices (defined in _BaseDataLoaderIter, shown here for reference)
        return next(self._sampler_iter)  # may raise StopIteration
```
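The single-process path is short enough to re-create in a few lines. ToySingleProcessIter is a hypothetical class, assuming a map-style dataset and an index sampler that already yields batches of indices; the fetch step is inlined as a list comprehension.

```python
from typing import Any, Callable, Iterable, List, Sequence


class ToySingleProcessIter:
    """Hypothetical sketch of the single-process path: next index -> fetch -> collate."""

    def __init__(self, dataset: Sequence[Any], index_sampler: Iterable[List[int]],
                 collate_fn: Callable[[List[Any]], Any]) -> None:
        self.dataset = dataset
        self.sampler_iter = iter(index_sampler)  # plays the role of _sampler_iter
        self.collate_fn = collate_fn

    def _next_index(self) -> List[int]:
        return next(self.sampler_iter)  # raises StopIteration at the end of the epoch

    def _next_data(self) -> Any:
        index = self._next_index()
        data = [self.dataset[i] for i in index]  # what a map-style fetcher does
        return self.collate_fn(data)

    def __iter__(self) -> "ToySingleProcessIter":
        return self

    def __next__(self) -> Any:
        return self._next_data()
```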

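2.3.4 Fetching Samples

Next, let's see how samples are fetched: the indices are passed to the fetcher, which returns the requested samples.

The fetcher is created as follows, when _SingleProcessDataLoaderIter is initialized: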

```python
class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
```

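For map-style datasets, _MapDatasetFetcher is used: it uses possibly_batched_index as the key(s) to extract data from the dataset.

If there is a batch sampler, possibly_batched_index is a whole batch of indices coming from it.

Finally, collate_fn post-processes the result, merging the list of samples into a mini-batch.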

```python
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # With a batch_sampler, _auto_collation is True and the batch_sampler
            # takes precedence, so the fetcher receives the indices of a whole batch
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
```
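A plain-Python mirror of _MapDatasetFetcher.fetch, for illustration only (toy_map_fetch is a hypothetical name). Anything that supports dataset[idx] works, so a list stands in for the dataset.

```python
from typing import Any, Callable, List, Sequence, Union


def toy_map_fetch(dataset: Sequence[Any], possibly_batched_index: Union[int, List[int]],
                  auto_collation: bool, collate_fn: Callable[[Any], Any]) -> Any:
    """Mimics _MapDatasetFetcher.fetch: index a map-style dataset by key(s), then collate."""
    if auto_collation:
        # a batch sampler is in use: the argument is a list of indices
        data = [dataset[idx] for idx in possibly_batched_index]
    else:
        # no batching: the argument is a single index
        data = dataset[possibly_batched_index]
    return collate_fn(data)
```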

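For iterable-style datasets, __init__ already creates an iterator over the dataset, so fetch takes elements straight from that iterator. With a regular sampler the index has no effect at all; with a batch sampler only the length of the index list matters, because it determines how many items are pulled for one batch.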

```python
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # auto_collation is True, i.e. a batch_sampler is used:
            # pull len(possibly_batched_index) samples, i.e. one batch
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # plain sampler: just advance the iterator and take one sample
            data = next(self.dataset_iter)
        return self.collate_fn(data)
```
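The iterable-style fetcher can be mirrored the same way. ToyIterableFetcher is a hypothetical copy of the logic above; the test values show that the index values are never read, only how many of them there are, and that drop_last turns a short final batch into StopIteration.

```python
from typing import Any, Callable, Iterable, List


class ToyIterableFetcher:
    """Mimics _IterableDatasetFetcher: index values are ignored, only their count matters."""

    def __init__(self, dataset: Iterable[Any], auto_collation: bool,
                 collate_fn: Callable[[Any], Any], drop_last: bool) -> None:
        self.dataset_iter = iter(dataset)
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index: Any) -> Any:
        if self.auto_collation:
            data: List[Any] = []
            for _ in possibly_batched_index:  # pull len(index) items, whatever the values are
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)  # one item per call
        return self.collate_fn(data)
```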

The overall structure at this point:

```text
     +--------------------------+            +-------------------------------+
     | DataLoader               |            | _SingleProcessDataLoaderIter  |
     |                          |            |                               |
     |                          |            |   __next__                    |
+----------+ Sampler            |            |                               |
|    |                          |            |   _next_data  +-----------------+
|    |                          |            |                               | |
|    |       Dataset            |            |   _next_index                 | |
|    |                          |            |                               | |
|    |       __iter__           |            |   _index_sampler              | |
|    |                          |            |      +                        | |
|    |       _get_iterator +---------------> |      |                        | |
|    +--------------------------+            +------+------------------------+ |
|                                                   |                          |
|          +----------------------------+           |                          |
|          | Sampler                    |           |                          |
+--------> |                            | <---------+                          |
           |                            |                                      |
           +----------------------------+                                      |
                                                                               |
           +----------------------------+                                      |
           | _BaseDatasetFetcher        |                                      |
           |                            |                                      |
           |   dataset                  |                                      |
           |                            | <------------------------------------+
           |   collate_fn               |
           |                            |
           +----------------------------+
```

The dynamic flow:

```text
User           DataLoader     _SingleProcessDataLoaderIter  _DatasetKind       Sampler

  +                 +                       +                     +               +
  |                 |                       |                     |               |
  |        1        |                       |                     |               |
enumerate -----> __iter__                   |                     |               |
  |                 |                       |                     |               |
  |                 |            2          |         3           |               |
  |           _get_iterator ----------> __init__ --------> create_fetcher         |
  |        4        |                       |                     |               |
  | <---------------+                       |                     |               |
  |     iterator    |                       |                     |               |
  |                 |                       |                     |               |
  |                 |           5           |                     |               |
for loop -----------------------------> __next__                  |               |
  |                 |                       |                     |               |
  |                 |                  _next_data                 |               |
  |                 |                       |                     |               |
  |                 |                       |                     |    6 next     |
  |                 |                  _next_index -----------------------------> |
  |                 |                       | <-----------------------------------+
  |                 |                       |                     |    7 index    |
  |                 |                       |                     |               |
  |                 |                       |    8 fetch(index)   |               |
  |                 |                       +-------------------> |               |
  |                 |                       | <-------------------+               |
  |                 |                       |        9 data       |               |
  |                 |                       |                     |               |
  | <-----------------------------------------+                   |               |
  |     10 data     |                       |                     |               |
  v                 v                       v                     v               v
```

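2.4 Multi-Process Loading

To speed things up, PyTorch supports multi-process loading: set num_workers to a positive integer and that many worker processes are created. In this mode, each worker is an independent process.

From the previous section we know that _SingleProcessDataLoaderIter is the core of single-process loading: the loader interacts with the sampler and the dataset through it. In the multi-process case, the corresponding core is _MultiProcessingDataLoaderIter.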

```python
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
```

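We start the analysis from _MultiProcessingDataLoaderIter.

2.4.1 Overall Logic

The comments in _MultiProcessingDataLoaderIter are very detailed and worth reading closely; they also contain the flow chart quoted below. The basic flow revolves around three queues:

  • The main process puts the indices of the data to fetch into index_queue, the queue that tells each worker which data to load. It also hands the workers a result queue, with two cases:

    • If pin memory is enabled, this is worker_result_queue.
    • Otherwise it is data_queue (in the code, data_queue is then simply the same object as worker_result_queue).
  • A worker reads indices from index_queue, loads the data, and puts (index, data) into worker_result_queue, the queue that returns results to the main process.
  • The main process then handles the results, again in two cases:
    • If pin memory is enabled, the main process's pin_memory_thread reads (index, data) from worker_result_queue, copies the data into pinned memory, and puts the result into data_queue, the queue of processed results.
    • If pin memory is not needed, the results are already in data_queue and nothing more is done.

As you can see, each worker's input is a queue (index_queue) and its output is also a queue (worker_result_queue). The main process and the workers are connected through these two or three queues, which decouples them and speeds up loading.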

```python
    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
```

In detail, without pin memory the flow is:

```text
                                                      +-----------+
                 indices      +-------------+ indices |  Worker   |  Data
            +---------------> | index queue +-------> |  Process  +-------+
            |                 +-------------+         |           |       |
            |                                         +-----------+       |
            |                                                             |     +------------+
            |                                                             +---> |            |
+---------+ |                                         +-----------+             |            |
|  Main   | |    indices      +-------------+ indices |  Worker   |  Data       | Data Queue |
| Process +-+---------------> | index queue +-------> |  Process  +-----------> |            |
|         | |                 +-------------+         |           |             |            |
+---------+ |                                         +-----------+             |            |
            |                                                             +---> |            |
            |                                                             |     +------------+
            |                                         +-----------+       |
            |    indices      +-------------+ indices |  Worker   |  Data |
            +---------------> | index queue +-------> |  Process  +-------+
                              +-------------+         |           |
                                                      +-----------+
```

With pin memory, the data first goes into the result queue, and after pin_memory_thread has processed it, it is moved to the data queue:

```text
                                                      +-----------+
                 indices      +-------------+ indices |  Worker   |  Data
            +---------------> | index queue +-------> |  Process  +-------+
            |                 +-------------+         |           |       |
            |                                         +-----------+       |
            |                                                             |     +--------------+
            |                                                             +---> |              |
+---------+ |                                         +-----------+             |              |
|  Main   | |    indices      +-------------+ indices |  Worker   |  Data       | result_queue |
| Process +-+---------------> | index queue +-------> |  Process  +-----------> |              |
|         | |                 +-------------+         |           |             |              |
+---------+ |                                         +-----------+             |              |
            |                                                             +---> |              |
            |                                                             |     +------+-------+
            |                                         +-----------+       |            |
            |    indices      +-------------+ indices |  Worker   |  Data |            |
            +---------------> | index queue +-------> |  Process  +-------+            |
                              +-------------+         |           |                    |
                                                      +-----------+                    |
                                                                                       v
                                                                             +-------------------+
                                                                             | pin_memory_thread |
                                                                             +---------+---------+
                                                                                       |
                                                                                       v
                                                                                +------------+
                                                                                | Data Queue |
                                                                                +------------+
```
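The queue topology can be tried out without PyTorch. In this sketch, threads and queue.Queue stand in for the worker processes and multiprocessing queues (an assumption made to keep it self-contained), toy_worker_loop, toy_pin_loop and run are hypothetical names, and the "pin" stage only tags the data instead of copying it into page-locked memory.

```python
import itertools
import queue
import threading
from typing import Any, List, Tuple


def toy_worker_loop(index_queue: queue.Queue, result_queue: queue.Queue,
                    dataset: List[Any]) -> None:
    """Worker: read (task_idx, indices) until the None sentinel, then put (task_idx, data)."""
    while True:
        r = index_queue.get()
        if r is None:  # final signal, like the `r is None` branch of _worker_loop
            break
        task_idx, indices = r
        result_queue.put((task_idx, [dataset[i] for i in indices]))


def toy_pin_loop(in_queue: queue.Queue, out_queue: queue.Queue, n: int) -> None:
    """Stand-in for pin_memory_thread: move n results, tagging the data as 'pinned'."""
    for _ in range(n):
        task_idx, data = in_queue.get()
        out_queue.put((task_idx, ("pinned", data)))


def run(dataset: List[Any], batches: List[List[int]], num_workers: int,
        pin_memory: bool) -> List[Tuple[int, Any]]:
    index_queues: List[queue.Queue] = [queue.Queue() for _ in range(num_workers)]
    worker_result_queue: queue.Queue = queue.Queue()
    workers = [threading.Thread(target=toy_worker_loop,
                                args=(q, worker_result_queue, dataset), daemon=True)
               for q in index_queues]
    for w in workers:
        w.start()
    if pin_memory:
        data_queue: queue.Queue = queue.Queue()
        threading.Thread(target=toy_pin_loop, daemon=True,
                         args=(worker_result_queue, data_queue, len(batches))).start()
    else:
        data_queue = worker_result_queue  # same object, as in the real __init__
    cycle = itertools.cycle(range(num_workers))
    for task_idx, indices in enumerate(batches):
        index_queues[next(cycle)].put((task_idx, indices))  # round-robin dispatch
    results = [data_queue.get(timeout=5) for _ in batches]
    for q in index_queues:
        q.put(None)  # tell every worker to exit
    for w in workers:
        w.join()
    return sorted(results, key=lambda r: r[0])  # arrival order may vary
```

Results are sorted by task index before returning because arrival order depends on thread scheduling; how the real iterator restores the sampler's order is covered in 2.4.7.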

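2.4.2 Initialization

The initialization function below mainly:

  • configures the iterator: creates the member variables and the queues;
  • starts the worker processes;
  • starts the pin_memory thread in the main process.

The main member variables are:

  • _index_queues: a list of queues, one per worker; each queue holds the data indices its worker needs to process.
  • _worker_result_queue: the (idx, data) pairs produced by the workers.
  • data_queue: the data after processing by the main process's pin_memory thread; if pinning is not needed, _worker_result_queue is used directly.
  • _worker_queue_idx_cycle: used to find the next worker to hand work to.

The code: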

```python
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # worker output: (idx, data) of fetched batches
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []  # worker inputs: indices of the data to fetch
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,  # worker entry point; receives the queues and functions
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            # keep this worker's index_queue in the main process so we can talk to it later
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # data after pinning
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue  # no pinning: use _worker_result_queue directly

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)  # finish the setup: dispatch tasks and prefetch
```

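2.4.3 Resetting the Iterator

__init__ ends by calling _reset, which completes the initialization and is also used to reset state between epochs.

The previous function started the worker processes but did not give them any work yet, so _reset takes care of task assignment and prefetching.

_MultiProcessingDataLoaderIter uses the following fields to coordinate the work of the workers (and the queues):

  • _send_idx: send index, the task (batch) index of the next batch to put into an index_queue.

  • _rcvd_idx: receive index, the task (batch) index of the next batch to take from data_queue and return from __next__. Together, _send_idx and _rcvd_idx provide flow control and guarantee that the receive index refers to a task that has actually been sent.

  • _task_info: a dict describing data that has not been yielded yet. The key is the task (batch) idx, an integer counting from 0; the value is (worker_id,) if the data has not arrived yet, or (worker_id, data) if it has already arrived out of order. It is used to look up which worker owns a task and to buffer out-of-order data.

  • _tasks_outstanding: an integer, the number of tasks sent to workers whose data has not been received yet (per the source comment, always equal to the number of entries in _task_info whose value has length 1). It is used mainly for sanity checks.

Each worker produces one batch per task, and before a batch is returned to the user, the indices of another batch are sent out. _reset therefore sets both _send_idx and _rcvd_idx back to 0, so that the next iteration starts from the beginning.

At the end of _reset there is a prefetch step, which we explain later together with out-of-order handling.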

```python
    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Note that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        # per-worker status
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                return_idx, return_data = self._get_data()
                if isinstance(return_idx, _utils.worker._ResumeIteration):
                    assert return_data is None
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        # (hand out several batches of indices up front; this works together with
        # the out-of-order handling described later)
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()
```

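2.4.4 Getting Indices

_try_put_index uses the sampler to get the indices of the next batch. _prefetch_factor defaults to 2. The main logic:

  • Get the next batch of indices from the sampler.
  • Use _worker_queue_idx_cycle to find the next active worker and hand the indices to it.
  • Update the bookkeeping in the main process.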

```python
    def _next_index(self):  # defined in the base class _BaseDataLoaderIter: get the next batch of indices
        return next(self._sampler_iter)  # may raise StopIteration

    def _try_put_index(self):
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

        try:
            index = self._next_index()  # get the next batch of indices
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:  # this worker is still active: use it
                break
        else:
            # not found (i.e., didn't break)
            return

        # Bookkeeping in the main process follows.
        # Put (task idx, data indices) into the chosen worker's queue; the worker loop
        # picks it up from the queue right away and starts fetching the data.
        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        # record information about the data that will be produced
        self._task_info[self._send_idx] = (worker_queue_idx,)
        # one more batch in flight
        self._tasks_outstanding += 1
        # _send_idx counts how many batches have been sent from the sampler iterator to index_queue
        self._send_idx += 1  # advance to the next task index
```
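The bookkeeping in _reset and _try_put_index can be replayed in isolation. ToyDispatcher is a hypothetical re-creation that uses plain lists instead of worker queues; it keeps the same fields (send_idx, task_info, tasks_outstanding, workers_status) and the same priming loop of prefetch_factor * num_workers calls.

```python
import itertools
from typing import Dict, Iterator, List, Tuple


class ToyDispatcher:
    """Hypothetical re-creation of the bookkeeping in _reset / _try_put_index."""

    def __init__(self, batches: List[List[int]], num_workers: int, prefetch_factor: int = 2) -> None:
        self.sampler_iter: Iterator[List[int]] = iter(batches)
        self.num_workers = num_workers
        self.worker_cycle = itertools.cycle(range(num_workers))
        self.workers_status = [True] * num_workers
        # one list per worker stands in for each index_queue
        self.index_queues: List[List[Tuple[int, List[int]]]] = [[] for _ in range(num_workers)]
        self.send_idx = 0
        self.task_info: Dict[int, Tuple[int, ...]] = {}
        self.tasks_outstanding = 0
        for _ in range(prefetch_factor * num_workers):  # prime the prefetch loop
            self.try_put_index()

    def try_put_index(self) -> None:
        try:
            index = next(self.sampler_iter)
        except StopIteration:
            return
        for _ in range(self.num_workers):  # find the next active worker
            worker_id = next(self.worker_cycle)
            if self.workers_status[worker_id]:
                break
        else:
            return  # every worker is inactive
        self.index_queues[worker_id].append((self.send_idx, index))
        self.task_info[self.send_idx] = (worker_id,)
        self.tasks_outstanding += 1
        self.send_idx += 1
```

After priming with 3 workers and prefetch_factor=2, six tasks are in flight, dealt round-robin to workers 0, 1, 2, 0, 1, 2; once worker 0 is marked inactive, the next task skips it.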

2.4.5 Worker Main Function

_worker_loop is the main function of a worker process. Its logic is summarized in its comment:

```python
    # [ worker processes ]
    #   While loader process is alive:
    #     Get from `index_queue`.
    #       If get anything else,
    #         Check `workers_done_event`.
    #           If set, continue to next iteration
    #                   i.e., keep getting until see the `None`, then exit.
    #           Otherwise, process data:
    #               If is fetching from an `IterableDataset` and the iterator
    #                   is exhausted, send an `_IterableDatasetStopIteration`
    #                   object to signal iteration end. The main process, upon
    #                   receiving such an object, will send `None` to this
    #                   worker and not use the corresponding `index_queue`
    #                   anymore.
    #       If timed out,
    #          No matter `workers_done_event` is set (still need to see `None`)
    #          or not, must continue to next iteration.
    #   (outside loop)
    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
    #                                          main process won't read from it;
    #                                          other workers will also call
    #                                          `cancel_join_thread`.)
```

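In short, the worker interacts with the main process through index_queue and data_queue:

  • it gets new data indices from index_queue;
  • unless this worker has been told to finish, it fetches the data with the fetcher;
  • it puts the data into data_queue, which notifies the main process. Note that data_queue is a function parameter here: the main process always passes in _worker_result_queue, which is either drained by the pin memory thread (pin memory enabled) or used directly as the main process's data_queue (pin memory disabled).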

```python
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np
            np.random.seed(np_seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():  # block here waiting for work
            try:
                # woken up as soon as _try_put_index puts indices into this queue
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    ...  # handling code omitted
            data_queue.put((idx, data))  # put the result, which notifies the main process
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
```
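The seeding at the top of _worker_loop gives each worker seed = base_seed + worker_id. The sketch below (toy_worker_seeds and first_draws are hypothetical helpers) uses one random.Random per worker in place of the per-process random.seed(seed) call, to show that the streams differ across workers but repeat for the same base_seed. In DataLoader, base_seed is drawn once per iterator from loader.generator (see _BaseDataLoaderIter.__init__), so passing a seeded generator makes the worker seeds reproducible.

```python
import random
from typing import List


def toy_worker_seeds(base_seed: int, num_workers: int) -> List[int]:
    """Each worker seeds itself with base_seed + worker_id, as in _worker_loop."""
    return [base_seed + worker_id for worker_id in range(num_workers)]


def first_draws(base_seed: int, num_workers: int) -> List[float]:
    """First random number each worker would draw after seeding."""
    draws = []
    for seed in toy_worker_seeds(base_seed, num_workers):
        rng = random.Random(seed)  # a private RNG stands in for random.seed(seed) in a worker
        draws.append(rng.random())
    return draws
```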

2.4.6 Pin memory thread

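In the main process, if pin memory is enabled, pin_memory_thread reads data from worker_result_queue, processes it (copying it into pinned memory, which speeds up CPU-to-GPU transfers), and puts the result into data_queue.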

```python
    # [ pin_memory_thread ]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While `pin_memory_thread_done_event` is not set:
    #     Get from `index_queue`.
    #       If timed out, continue to get in the next iteration.
    #       Otherwise, process data.
    #       While `pin_memory_thread_done_event` is not set:
    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
    #         If timed out, continue to put in the next iteration.
    #         Otherwise, break, i.e., continuing to the out loop.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
```

The code:

```python
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            data = pin_memory(data)
            # exception handling omitted
        r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        del r  # save memory


def pin_memory(data):
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    elif isinstance(data, string_classes):
        return data
    elif isinstance(data, collections.abc.Mapping):
        return {k: pin_memory(sample) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(sample) for sample in data))
    elif isinstance(data, collections.abc.Sequence):
        return [pin_memory(sample) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data
```
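pin_memory() walks nested containers and pins only the leaves. Below, FakeTensor is a hypothetical stand-in for torch.Tensor (its pin_memory() merely returns a flagged copy), and toy_pin follows the same dispatch order as the function above, so the traversal can be exercised without a GPU.

```python
import collections.abc
from typing import Any, NamedTuple


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; pin_memory() returns a flagged copy."""

    def __init__(self, value: Any, pinned: bool = False) -> None:
        self.value = value
        self.pinned = pinned

    def pin_memory(self) -> "FakeTensor":
        return FakeTensor(self.value, pinned=True)


class Pair(NamedTuple):
    x: Any
    y: Any


def toy_pin(data: Any) -> Any:
    """Same dispatch order as pin_memory(): leaf, str, Mapping, namedtuple, Sequence."""
    if isinstance(data, FakeTensor):
        return data.pin_memory()
    if isinstance(data, str):
        return data
    if isinstance(data, collections.abc.Mapping):
        return {k: toy_pin(v) for k, v in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple keeps its type
        return type(data)(*(toy_pin(v) for v in data))
    if isinstance(data, collections.abc.Sequence):
        return [toy_pin(v) for v in data]  # other sequences become lists
    return data
```

One consequence of the Sequence branch, shared by the real function: a plain tuple comes back as a list, while a namedtuple keeps its type.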

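2.4.7 How the User Gets Data

Now that the data has been loaded, let's look at how the user gets data out of the DataLoader.

A key issue here is keeping the read order consistent across experiments. To make several runs comparable, the data should be read in the same order every time, so that differences in the results are not caused by the data.

The biggest threat to this consistency is out-of-order data, which comes from multi-process reading: some workers are fast and others slow. For example, the user needs to read 6-19, 16-26 and 37-46, in that order. If the worker handling 6-19 is slow and cannot return in time while another worker returns 16-26 first, the data arrives out of order.

How is out-of-order data handled? PyTorch's approach is that DataLoader returns data strictly in the order given by the Sampler. If a piece of data arrives out of order, it is buffered and the loader moves on to receive the next piece (see the "store out-of-order samples" comment in the code below); it is returned only when its turn comes.

The risk is that data can come back more slowly than it is requested: for example, 6 should be returned next, but the data queue only holds 16 and 27, so the user has to wait until 6 has been loaded.

The remedy for this slowness is prefetching: at the end of _reset, a number of indices are handed out in advance so that the workers start loading early. Prefetching does not make the arrival order deterministic, but it keeps the workers busy so that the batch whose turn has come is more likely to be ready; the order in which batches reach the user is guaranteed by the buffering described above.

The code follows. First, recall the base class's __next__, which calls _next_data to get the data.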

```python
class _BaseDataLoaderIter(object):
    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()  # fetch the data
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                # construction of the warning message omitted
                warnings.warn(warn_msg)
            return data
```

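So what we need to look at is _MultiProcessingDataLoaderIter's _next_data:

  • Because indices were prefetched, the workers have already started loading, so the main process can usually get data right away; if none is available yet, it keeps waiting in the while True loop.
  • Once a batch is obtained, _process_data is called, which sends out the next index and prepares the next iteration.
  • Out-of-order data is recorded in _task_info: data that cannot be returned yet is buffered there.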

    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            # (Find the next batch index to receive.)
            while self._rcvd_idx < self._send_idx:  # batch index to receive < next batch index to send
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break  # data is ready or the worker is still running: leave this inner loop
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)  # advance the index and dispatch the next task

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()  # take data from self._data_queue
            self._tasks_outstanding -= 1  # one fewer batch in flight

            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:  # out-of-order data
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]  # in-order data
                return self._process_data(data)  # advance the index and dispatch the next task
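
The reordering done by `_next_data` can be isolated in a few lines. The sketch below is a simplified, single-threaded model under these assumptions: `deliver_in_order` is a hypothetical helper, results arrive as `(idx, data)` pairs in arbitrary order, and a plain dict `task_info` stands in for `self._task_info` holding out-of-order samples until `rcvd_idx` reaches them:

```python
def deliver_in_order(arrivals):
    """Hypothetical helper: return batch payloads ordered by batch index,
    even though (idx, data) pairs arrive in arbitrary order."""
    task_info = {}  # idx -> data that arrived early (like self._task_info)
    rcvd_idx = 0    # next batch index that may be returned (like self._rcvd_idx)
    delivered = []
    for idx, data in arrivals:
        if idx != rcvd_idx:
            task_info[idx] = data  # out of order: store it for later
            continue
        delivered.append(data)     # in order: hand it out
        rcvd_idx += 1
        # Like the "already generated" check at the top of _next_data
        while rcvd_idx in task_info:
            delivered.append(task_info.pop(rcvd_idx))
            rcvd_idx += 1
    return delivered
```

For example, `deliver_in_order([(1, 'b'), (0, 'a'), (3, 'd'), (2, 'c')])` returns `['a', 'b', 'c', 'd']`. Batch 1 waits in the buffer until batch 0 arrives.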

Next, let's look at how _get_data takes data from self._data_queue. The actual read is performed by _try_get_data.

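  • If a timeout is configured, read once with that timeout, and raise RuntimeError if it expires.
  • If pin memory is enabled, read the data already processed by the pin_memory thread, for as long as that thread is alive.
  • Otherwise, keep reading the data produced by the workers in a loop until some data arrives.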

    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0:  # a timeout is configured: read with that timeout
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:  # read data already processed by the pin_memory thread
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()  # read data produced by the workers
                if success:
                    return data
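
The `pin_memory` branch of `_get_data` relies on Python's `while ... else`. The `else` clause runs only when the loop condition becomes false, that is, when the producer thread has died without delivering data. Here is a minimal sketch of that pattern. It uses a standard `queue.Queue`, and a `threading.Thread` stands in for the pin-memory thread. `get_from_pin_thread` and the 0.05 s polling interval are assumptions made for illustration:

```python
import queue
import threading


def get_from_pin_thread(data_queue, pin_thread, interval=0.05):
    """Poll data_queue while pin_thread is alive; raise if it dies first."""
    while pin_thread.is_alive():
        try:
            return data_queue.get(timeout=interval)
        except queue.Empty:
            continue  # timed out: loop back and re-check the thread
    else:
        # Reached only when the while condition is false: the thread died
        raise RuntimeError('Pin memory thread exited unexpectedly')


if __name__ == "__main__":
    q = queue.Queue()
    stop = threading.Event()

    def producer():
        q.put("batch-0")  # hand over one result, then stay alive until told to stop
        stop.wait()

    t = threading.Thread(target=producer)
    t.start()
    print(get_from_pin_thread(q, t))  # batch-0
    stop.set()
    t.join()
```

The short timeout is what makes the liveness check possible. A blocking `get()` with no timeout would hang forever if the producer died.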

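_try_get_data performs a single get on _data_queue with a timeout and returns a (success, data) tuple. On a timeout or error, it also checks whether any worker has died. The main process and the worker processes communicate through put and get calls on these queues.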

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
            # (exception handling omitted: raises RuntimeError if any worker failed,
            # and returns (False, None) on a plain queue.Empty timeout)
            import tempfile
            import errno
            try:
                # Raise an exception if we are this close to the FDs limit.
                # Apparently, trying to open only one file is not a sufficient
                # test.
                # See NOTE [ DataLoader on Linux and open files limit ]
                fds_limit_margin = 10
                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
            except OSError as e:
                ...  # (exception handling omitted)
            raise
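
Here is a stripped-down model of `_try_get_data`. It assumes hypothetical `FakeWorker` objects in place of worker processes. The real method catches any `Exception` and also checks the file-descriptor limit. This sketch handles only `queue.Empty`. It returns `(True, data)` on success and `(False, None)` on a plain timeout. It raises `RuntimeError` if a worker still marked active has died:

```python
import queue


class FakeWorker:
    """Hypothetical stand-in for a worker process: only pid and is_alive() are used."""

    def __init__(self, pid, alive=True):
        self.pid = pid
        self.alive = alive

    def is_alive(self):
        return self.alive


def try_get_data(data_queue, workers, workers_status, timeout=0.05):
    """One get() with a timeout; returns (success, data) like _try_get_data."""
    try:
        return (True, data_queue.get(timeout=timeout))
    except queue.Empty as e:
        # On timeout, check whether any worker that should be running has died
        failed = []
        for worker_id, w in enumerate(workers):
            if workers_status[worker_id] and not w.is_alive():
                failed.append(w)
                workers_status[worker_id] = False  # like _mark_worker_as_unavailable
        if failed:
            pids = ', '.join(str(w.pid) for w in failed)
            raise RuntimeError('worker (pid(s) {}) exited unexpectedly'.format(pids)) from e
        return (False, None)  # plain timeout: the caller decides whether to retry
```

Returning a tuple instead of raising on timeout lets callers such as `_get_data` loop cheaply and re-check worker and thread liveness between attempts.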

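_process_data prepares the next iteration. It increments _rcvd_idx, calls _try_put_index to dispatch another batch of indices, and re-raises any exception that a worker sent back wrapped in an ExceptionWrapper.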

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()  # dispatch the next index for a future iteration
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data  # return the data
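
`_process_data` re-raises any `ExceptionWrapper` it receives. This is how an exception raised inside a worker process surfaces in the main process. The sketch below models that idea in a single process. `WrappedException`, `worker_fetch`, and `process_data` are hypothetical names. Unlike PyTorch's `ExceptionWrapper`, this sketch does not special-case exception types whose constructors need extra arguments:

```python
import sys
import traceback


class WrappedException:
    """Hypothetical mini ExceptionWrapper: capture in the worker, re-raise later."""

    def __init__(self, where="in worker"):
        # Must be constructed inside an `except` block so sys.exc_info() is set
        self.exc_type = sys.exc_info()[0]
        self.msg = "Caught {} {}.\nOriginal {}".format(
            self.exc_type.__name__, where, traceback.format_exc())

    def reraise(self):
        raise self.exc_type(self.msg)


def worker_fetch(idx):
    """Simulated fetch in a worker: batch 2 is 'corrupt' and fails."""
    try:
        if idx == 2:
            raise ValueError("bad sample")
        return idx * 10
    except Exception:
        return WrappedException(where="in worker for batch {}".format(idx))


def process_data(data):
    """Like _process_data minus the index bookkeeping: re-raise wrapped errors."""
    if isinstance(data, WrappedException):
        data.reraise()
    return data
```

Wrapping the error as ordinary data means it travels through the same queue as normal batches. A worker failure therefore arrives at the consumer in batch order, not as a crash in the worker.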

2.4.8 Summary

Let us summarize the multi-process logic.

The overall logic is as follows:

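  • The main process puts the indices of the data to be fetched into index_queue.
  • Each worker subprocess reads indices from its index_queue and loads the data. It then puts the result, tagged with its batch index, into worker_result_queue.
  • The pin_memory_thread in the main process reads these (index, data) results from worker_result_queue, copies the data into pinned memory, and puts the result into data_queue.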
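
The three-stage flow above can be simulated with threads. The simulation makes these assumptions: threads stand in for worker processes, `tuple()` is a placeholder for copying into pinned memory, and `None` is used as the shutdown sentinel on each index queue. All function names are hypothetical. Only the queue topology follows the description: one index queue per worker, one `worker_result_queue`, and one `data_queue`.

```python
import queue
import threading


def worker_loop(dataset, index_queue, result_queue):
    # Like _worker_loop: get (idx, indices), fetch the samples, put (idx, data)
    while True:
        task = index_queue.get()
        if task is None:                  # sentinel: this worker should exit
            break
        idx, indices = task
        result_queue.put((idx, [dataset[i] for i in indices]))


def pin_loop(in_queue, out_queue, done):
    # Like the pin_memory thread: move results along (real code pins tensors here)
    while not done.is_set():
        try:
            idx, data = in_queue.get(timeout=0.05)
        except queue.Empty:
            continue
        out_queue.put((idx, tuple(data)))  # tuple() stands in for "pinning"


def run_pipeline(dataset, batches, num_workers=2):
    index_queues = [queue.Queue() for _ in range(num_workers)]
    worker_result_queue, data_queue = queue.Queue(), queue.Queue()
    workers = [threading.Thread(target=worker_loop,
                                args=(dataset, iq, worker_result_queue))
               for iq in index_queues]
    done = threading.Event()
    pinner = threading.Thread(target=pin_loop,
                              args=(worker_result_queue, data_queue, done))
    for t in workers + [pinner]:
        t.start()
    for idx, indices in enumerate(batches):   # main process: round-robin dispatch
        index_queues[idx % num_workers].put((idx, indices))
    results = dict(data_queue.get() for _ in batches)  # may arrive out of order
    for iq in index_queues:
        iq.put(None)
    done.set()
    for t in workers + [pinner]:
        t.join()
    return [results[i] for i in range(len(batches))]  # restore batch order
```

For instance, `run_pipeline(list('abcdef'), [[0, 1], [2, 3], [4, 5]])` returns `[('a', 'b'), ('c', 'd'), ('e', 'f')]` regardless of which worker finishes first.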

The detailed flow is as follows:

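  1. The initializer __init__ of _MultiProcessingDataLoaderIter does the following:

    • configures the iterator: creates the member variables and sets up the queues;
    • starts the worker subprocesses;
    • starts the pin_memory thread in the main process;
    • calls _reset, which completes the initialization and is also used to reset state. The worker subprocesses started above have no tasks yet, so _reset assigns them work by prefetching indices.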
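  2. Next comes a prefetch step (keep it in mind when reading the diagram below).
    • _try_put_index uses the sampler to obtain the indices of the next batch. _prefetch_factor defaults to 2, and _reset calls _try_put_index prefetch_factor * num_workers times. Its main logic is:

      • use _next_index to get the next batch of indices from the sampler;
      • use _worker_queue_idx_cycle to find the next active worker and hand the indices to it;
      • update the main process's bookkeeping (_send_idx, _task_info, _tasks_outstanding).
    • Once a worker process (_worker_loop) receives the indices, it does the actual fetching and interacts with the main process through index_queue and data_queue:
      • it gets a new batch of indices from index_queue;
      • unless this worker has been told to finish, it uses the fetcher to fetch the data;
      • it then puts the data into data_queue, which notifies the main process. Note that data_queue is a parameter of _worker_loop: if pin memory is enabled, the queue passed in is worker_result_queue; otherwise it is data_queue.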
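  3. When the user iterates, the __next__ method of the base iterator class (_BaseDataLoaderIter) is called, and it calls _next_data to get data from the DataLoader:
    • _get_data takes data from self._data_queue;
    • _process_data prepares the next iteration: it calls _try_put_index, which uses _next_index to dispatch the next batch of indices.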
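
The prefetch and round-robin dispatch described in step 2 can be sketched as plain bookkeeping. This `Dispatcher` class is hypothetical and uses Python lists instead of multiprocessing queues. Its field names mirror `_send_idx`, `_task_info`, `_tasks_outstanding`, and `_worker_queue_idx_cycle`. Its `reset` primes `prefetch_factor * num_workers` tasks, which is assumed to match what `_reset` does in the PyTorch version discussed here:

```python
import itertools


class Dispatcher:
    """Hypothetical sketch of the _try_put_index bookkeeping."""

    def __init__(self, sampler_batches, num_workers, prefetch_factor=2):
        self.sampler_iter = iter(sampler_batches)   # stands in for the batch sampler
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.worker_cycle = itertools.cycle(range(num_workers))
        self.workers_status = [True] * num_workers
        self.index_queues = [[] for _ in range(num_workers)]  # lists stand in for queues
        self.task_info = {}          # send_idx -> (worker_id,)
        self.tasks_outstanding = 0
        self.send_idx = 0

    def try_put_index(self):
        # Never exceed prefetch_factor * num_workers batches in flight
        assert self.tasks_outstanding < self.prefetch_factor * self.num_workers
        try:
            index = next(self.sampler_iter)      # like _next_index()
        except StopIteration:
            return
        for _ in range(self.num_workers):        # find the next active worker
            worker_id = next(self.worker_cycle)
            if self.workers_status[worker_id]:
                break
        else:
            return                               # no active worker left
        self.index_queues[worker_id].append((self.send_idx, index))
        self.task_info[self.send_idx] = (worker_id,)
        self.tasks_outstanding += 1
        self.send_idx += 1

    def reset(self):
        # Like _reset: prime the pipeline before the user asks for any data
        for _ in range(self.prefetch_factor * self.num_workers):
            self.try_put_index()
```

With two workers and `prefetch_factor=2`, `reset()` dispatches batches 0 and 2 to worker 0 and batches 1 and 3 to worker 1. Each later `_process_data` call then tops the pipeline up by one batch.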

The sequence of calls is as follows:

Iter = _MultiProcessingDataLoaderIter

   sender         -> receiver      : message
1  Iter           -> Iter          : __init__ calls _reset
2  Iter           -> Sampler       : _try_put_index -> _next_index (next)
   Sampler        -> Iter          : index
   Iter           -> index_queue   : put (index)
   _worker_loop   -> index_queue   : get
   index_queue    -> _worker_loop  : index
   _worker_loop   -> Fetcher       : index
   Fetcher        -> _worker_loop  : data
   _worker_loop   -> data_queue    : put (data)
3  user           -> Iter          : next
   Iter           -> data_queue    : _next_data -> _get_data -> _try_get_data (get)
   data_queue     -> Iter          : data
   Iter           -> Sampler       : _process_data -> _try_put_index -> _next_index (next)
   Sampler        -> Iter          : index
   Iter           -> index_queue   : put (index); a worker gets it, fetches, and puts the data
   Iter           -> user          : data


2.5 Pipeline

We can now refine the earlier pipeline diagram as follows:

                                              +------------+
                         +--------+           |            |
                         |        |           | Process 1  |
                 +-----> | Data 1 +---------> | Load Data  +------+
                 |       |        |           |            |      |
                 |       +--------+           +------------+      |
                 |                                                |
                 |                                                |
+-------------+  |                            +------------+      |                                    +-------------------------+
|Main process |  |       +--------+           |            |      |      +----------------------+      |                         |      +------------+
|             |  |       |        |           | Process 2  |      |      |                      |      | pin_memory_thread       |      |            |
| index_queue +--+-----> | Data 2 +---------> | Load Data  +------+----> | _worker_result_queue +----> | Write to pinned memory  +----> | data_queue |
|             |  |       |        |           |            |      |      |                      |      |                         |      |            |
+-------------+  |       +--------+           +------------+      |      +----------------------+      +-------------------------+      +------------+
                 |                                                |
                 |                                                |
                 |                            +------------+      |
                 |       +--------+           |            |      |
                 |       |        |           | Process 3  |      |
                 +-----> | Data 3 +---------> | Load Data  +------+
                         |        |           |            |
                         +--------+           +------------+


This completes our analysis of data loading in PyTorch distributed. In the next article we return to how Paracel handles data loading.
