1. 程式人生 > >Pytorch入門學習(四)---- 多GPU的使用

Pytorch入門學習(四)---- 多GPU的使用

DataParrallel

import torch.nn as nn


class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)

        # wrap block2 in DataParallel
        self.block2 = nn.Linear(20, 20)
        self.block2 = nn.DataParallel(self.block2)

        self.block3 = nn.Linear(20
, 20) def forward(self, x): x = self.block1(x) x = self.block2(x) x = self.block3(x) return x

這程式碼在CPU模式下也不需要改變。
DataParrallel中有一些基本型別。
- replicate: 將一個module複製到多個裝置上
- scatter: 將輸入第一維分配到不同GPU上。
- gather: gather and concatenate輸入的第一維度
- parralel_apply: 將已經分配的輸入用於一系列已分配的模型上。

def data_parallel(module, input, device_ids, output_device=None):
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    replicas = nn.parallel.replicate(module, device_ids)
    inputs = nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return
nn.parallel.gather(outputs, output_device)

將模型部分放在CPU,部分放在GPU

這個太強大了。。

class DistributedModel(nn.Module):

    def __init__(self):
        super().__init__(
            embedding=nn.Embedding(1000, 10),
            rnn=nn.Linear(10, 10).cuda(0),
        )

    def forward(self, x):
        # Compute embedding on CPU
        x = self.embedding(x)

        # Transfer to GPU
        x = x.cuda(0)

        # Compute RNN on GPU
        x = self.rnn(x)
        return x