Pytorch入門學習(四)---- 多GPU的使用
阿新 • • 發佈:2019-01-23
DataParrallel
import torch.nn as nn
class DataParallelModel(nn.Module):
def __init__(self):
super().__init__()
self.block1 = nn.Linear(10, 20)
# wrap block2 in DataParallel
self.block2 = nn.Linear(20, 20)
self.block2 = nn.DataParallel(self.block2)
self.block3 = nn.Linear(20 , 20)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
return x
這程式碼在CPU模式下也不需要改變。
DataParrallel中有一些基本型別。
- replicate: 將一個module複製到多個裝置上
- scatter: 將輸入第一維分配到不同GPU上。
- gather: gather and concatenate輸入的第一維度
- parralel_apply: 將已經分配的輸入用於一系列已分配的模型上。
def data_parallel(module, input, device_ids, output_device=None):
if not device_ids:
return module(input)
if output_device is None:
output_device = device_ids[0]
replicas = nn.parallel.replicate(module, device_ids)
inputs = nn.parallel.scatter(input, device_ids)
replicas = replicas[:len(inputs)]
outputs = nn.parallel.parallel_apply(replicas, inputs)
return nn.parallel.gather(outputs, output_device)
將模型部分放在CPU,部分放在GPU
這個太強大了。。
class DistributedModel(nn.Module):
def __init__(self):
super().__init__(
embedding=nn.Embedding(1000, 10),
rnn=nn.Linear(10, 10).cuda(0),
)
def forward(self, x):
# Compute embedding on CPU
x = self.embedding(x)
# Transfer to GPU
x = x.cuda(0)
# Compute RNN on GPU
x = self.rnn(x)
return x