PyTorch深度學習框架入門——使用PyTorch實現手寫數字識別
本文你將會了解到
1、如何使用PyTorch對資料集進行匯入
2、如何使用PyTorch搭建一個簡易的深度學習模型
3、如何優化和訓練我們搭建好的模型
注:本案例使用的PyTorch為0.4版本
簡介
Pytorch是目前非常流行的深度學習框架,因為它具備了Python的特性所以極易上手和使用,同時又兼具了NumPy的特性,因此在效能上也並不遜於任何一款深度學習框架。現在PyTorch又和Caffe2進行了融合,在今年暑期整和了Caffe2的PyTorch1.0版本將受到更多專業人士的關注和重視。下面我們通過使用PyTorch實現一個手寫數字識別的模型來簡單的入門一下PyTorch。
如何使用PyTorch對資料集進行匯入
在進行資料匯入之前我們需要先匯入一些在之後需要用的包,程式碼如下:
import torch
import torchvision
from torch.autograd importVariable
from torchvision import transforms,datasets
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
手寫數字的資料集我們可以使用PyTorch中自帶的torchvision.datasets方法進行下載。另外這個方法還可以方便的下載COCO,ImageNet,CIFCAR等常用的資料集。程式碼如下:
transform = transforms.ToTensor() train_dataset = datasets.MNIST(root ="./data",
train =True,
transform = transform,
download =True)
test_dataset = datasets.MNIST(root ="./data",
train
transform = transform,
download =True)
train_data_loader = torch.utils.data.DataLoader(dataset = train_dataset,
batch_size=64,
shuffle =True)
test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=64,
shuffle =True)
通過這幾段程式碼我們就實現了資料的載入和裝載。其中batch_size指定我們每次裝載的資料個數,這裡使用的值是64即我們每次裝載到模型中的圖片個數是64張。shuffle設定為True表明我們裝載到模型中的輸入資料是被隨機打亂順序的。
如何使用PyTorch搭建一個簡易的深度學習模型
定義好了資料載入和裝載的方法之後,我們就可以開始搭建深度學習模型,這裡使用卷積層、最大池化層和全連線層來搭建一個簡易的卷積神經網路模型,程式碼如下:
classModel(torch.nn.Module):
def __init__(self):
super(Model,self).__init__()
self.Conv=torch.nn.Sequential(
torch.nn.Conv2d(1,28,kernel_size=3,padding=1,stride=1),
torch.nn.Conv2d(28,64,kernel_size=3,padding=1,stride=1), torch.nn.MaxPool2d(kernel_size=2, stride=2), torch.nn.Conv2d(64,64,kernel_size=3,padding=1,stride=1), torch.nn.Conv2d(64,64,kernel_size=3,padding=1,stride=1), torch.nn.MaxPool2d(kernel_size=2, stride=2)
)
self.Dense= torch.nn.Linear(7*7*64,10)
def forward(self, input):
x =self.Conv(input)
x = x.view(-1,7*7*64)
x =self.Dense(x)
return x ```
模型使用的是4層卷積、2層池化和1層全連線的卷積神經網路結構。雖然結構簡單,但是對於處理手寫數字識別問題這個模型已經綽綽有餘了。搭建好模型之後我們可以對模型進行列印輸出,檢視具體的模型細節,如下:
model =Model()
model = model.cuda()
print(model)
輸出的結果為:
Model(
(Conv):Sequential(
(0):Conv2d(1,28, kernel_size=(3,3), stride=(1,1), padding=(1,1))
(1):Conv2d(28,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
(2):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3):Conv2d(64,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
(4):Conv2d(64,64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
(5):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(Dense):Linear(in_features=3136, out_features=10, bias=True)
)
如何優化和訓練我們搭建好的模型
模型已經搭建好了,資料的裝載方式也已經完成了定義,下面就差最後一步了,對我們的模型進行訓練並優化模型內部的引數。程式碼如下: