1. 程式人生 > >PyTorch深度學習框架入門——使用PyTorch實現手寫數字識別

PyTorch深度學習框架入門——使用PyTorch實現手寫數字識別

本文你將會了解到

1、如何使用PyTorch對資料集進行匯入 

2、如何使用PyTorch搭建一個簡易的深度學習模型 

3、如何優化和訓練我們搭建好的模型 

注:本案例使用的PyTorch為0.4版本

簡介

Pytorch是目前非常流行的深度學習框架,因為它具備了Python的特性所以極易上手和使用,同時又兼具了NumPy的特性,因此在效能上也並不遜於任何一款深度學習框架。現在PyTorch又和Caffe2進行了融合,在今年暑期整和了Caffe2的PyTorch1.0版本將受到更多專業人士的關注和重視。下面我們通過使用PyTorch實現一個手寫數字識別的模型來簡單的入門一下PyTorch。

如何使用PyTorch對資料集進行匯入

在進行資料匯入之前我們需要先匯入一些在之後需要用的包,程式碼如下:

  1. import torch

  2. import torchvision

  3. from torch.autograd importVariable

  4. from torchvision import transforms,datasets

  5. import numpy as np

  6. import matplotlib.pyplot as plt

  7. %matplotlib inline

手寫數字的資料集我們可以使用PyTorch中自帶的torchvision.datasets方法進行下載。另外這個方法還可以方便的下載COCO,ImageNet,CIFCAR等常用的資料集。程式碼如下:

  1. transform = transforms.ToTensor()                            train_dataset = datasets.MNIST(root ="./data",

  2.                               train =True,

  3.                               transform = transform,

  4.                               download =True)

  5. test_dataset = datasets.MNIST(root ="./data",

  6.                              train

    =False,

  7.                              transform = transform,

  8.                              download =True)

  9. train_data_loader = torch.utils.data.DataLoader(dataset = train_dataset,

  10.                                        batch_size=64,

  11. shuffle =True)

  12. test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,

  13. batch_size=64,

  14. shuffle =True)

通過這幾段程式碼我們就實現了資料的載入和裝載。其中batch_size指定我們每次裝載的資料個數,這裡使用的值是64即我們每次裝載到模型中的圖片個數是64張。shuffle設定為True表明我們裝載到模型中的輸入資料是被隨機打亂順序的。

如何使用PyTorch搭建一個簡易的深度學習模型

定義好了資料載入和裝載的方法之後,我們就可以開始搭建深度學習模型,這裡使用卷積層、最大池化層和全連線層來搭建一個簡易的卷積神經網路模型,程式碼如下:

  1. classModel(torch.nn.Module):

  2. def __init__(self):

  3. super(Model,self).__init__()

  4. self.Conv=torch.nn.Sequential(

  5. torch.nn.Conv2d(1,28,kernel_size=3,padding=1,stride=1),

  6.                                        torch.nn.Conv2d(28,64,kernel_size=3,padding=1,stride=1),                              torch.nn.MaxPool2d(kernel_size=2, stride=2),      torch.nn.Conv2d(64,64,kernel_size=3,padding=1,stride=1),                               torch.nn.Conv2d(64,64,kernel_size=3,padding=1,stride=1),                                 torch.nn.MaxPool2d(kernel_size=2, stride=2)

  7. )

  8. self.Dense= torch.nn.Linear(7*7*64,10)

  9. def forward(self, input):

  10.        x =self.Conv(input)

  11.        x = x.view(-1,7*7*64)

  12.        x =self.Dense(x)

  13. return x                           ```                  

模型使用的是4層卷積、2層池化和1層全連線的卷積神經網路結構。雖然結構簡單,但是對於處理手寫數字識別問題這個模型已經綽綽有餘了。搭建好模型之後我們可以對模型進行列印輸出,檢視具體的模型細節,如下:

  1. model =Model()

  2. model = model.cuda()

  3. print(model)

輸出的結果為:

  1. Model(

  2. (Conv):Sequential(

  3. (0):Conv2d(1,28, kernel_size=(3,3), stride=(1,1), padding=(1,1))

  4. (1):Conv2d(28,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))

  5. (2):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

  6. (3):Conv2d(64,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))

  7. (4):Conv2d(64,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))

  8. (5):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

  9. )

  10. (Dense):Linear(in_features=3136, out_features=10, bias=True)

  11. )

如何優化和訓練我們搭建好的模型

模型已經搭建好了,資料的裝載方式也已經完成了定義,下面就差最後一步了,對我們的模型進行訓練並優化模型內部的引數。程式碼如下: