PyTorch基礎入門六:PyTorch搭建卷積神經網路實現MNIST手寫數字識別
1)卷積神經網路(CNN)簡介
關於什麼是卷積神經網路(CNN),請自行查閱資料進行學習。如果是初學者,這裡推薦一下臺灣的李巨集毅的深度學習課程。連結就不給了,這些資料網站上隨處可見。
值得一提的是,CNN雖然在影象處理的領域具有不可阻擋的勢頭,但是它絕對不僅僅只能用來影象處理領域,大家熟知的alphaGo下圍棋也可以通過CNN的結構進行處理,因為下圍棋與影象有著相似之處,所以說,CNN提供給我們的是一種處理問題的思想,有學者歸納出了可以用CNN解決的問題所具備的三個性質:
-
區域性性
對於一張圖片而言,需要檢測圖片中的特徵來決定圖片的類別,通常情況下這些特徵都不是由整張圖片決定的,而是由一些區域性的區域決定的。例如在某張圖片中的某個區域性檢測出了鳥喙,那麼基本可以判定圖片中有鳥這種動物。
-
相同性
對於不同的圖片,它們具有同樣的特徵,這些特徵會出現在圖片的不同位置,也就是說可以用同樣的檢測模式去檢測不同圖片的相同特徵,只不過這些特徵處於圖片中不同的位置,但是特徵檢測所做的操作幾乎一樣。例如在不同的圖片中,雖然鳥喙處於不同的位置,但是我們可以用相同的模式去檢測。
-
不變性
對於一張圖片,如果我們進行下采樣,那麼圖片的性質基本保持不變。
2)PyTorch中的卷積神經網路
簡要介紹一下PyTorch中卷積神經網路中用到的一些方法。
-
卷積層:nn.Conv2d()
其引數如下:
引數· | 含義 |
in_channels | 輸入訊號的通道數. |
out_channels | 卷積後輸出結果的通道數. |
kernel_size | 卷積核的形狀. 例如kernel_size=(3, 2)表示3X2的卷積核,如果寬和高相同,可以只用一個數字表示 |
stride | 卷積每次移動的步長, 預設為1. |
padding | 處理邊界時填充0的數量, 預設為0(不填充). |
dilation | 取樣間隔數量, 預設為1, 無間隔取樣. |
groups | 輸入與輸出通道的分組數量. 當不為1時, 預設為1(全連線). |
bias | 為 True 時, 新增偏置. |
當然,這麼多引數有一些是不常用的,讀者只需要在實踐中慢慢體會一些常用的即可,其他引數需要將理論打紮實之後去官網查閱。
-
池化層:nn.MaxPool2d()
其引數如下:
引數 | 含義 |
kernel_size | 最大池化操作時的視窗大小 |
stride | 最大池化操作時視窗移動的步長, 預設值是 kernel_size |
padding | 輸入的每條邊隱式補0的數量 |
dilation | 用於控制視窗中元素的步長的引數 |
return_indices | 如果等於 True, 在返回 max pooling 結果的同時返回最大值的索引 這在之後的 Unpooling 時很有用 |
ceil_mode | 如果等於 True, 在計算輸出大小時,將採用向上取整來代替預設的向下取整的方式 |
3)實現MNIST手寫數字識別
一共定義了五層,其中兩層卷積層,兩層池化層,最後一層為FC層進行分類輸出。其網路結構如下:
中間一行表示當前資料塊的維度,第一個維度為深度,後面兩個為寬度和高度。輸入資料為灰度圖,所以深度為1,圖片畫素為28*28的圖片,後面經過卷積,池化,會發現深度不斷加深,而寬度和高度會逐漸減少,因此,最後CNN處理過的圖片只是一個區域性的圖片,換句話說,計算機在進行CNN對圖片進行識別的時候,它通過觀察圖片區域性的資訊來進行分類的,這一點和我們通過人眼來觀察圖片進行分類是不一樣的。
下面是CNN網路的程式碼實現:
# !/usr/bin/python
# coding: utf8
# @Time : 2018-08-05 19:22
# @Author : Liam
# @Email : [email protected]
# @Software: PyCharm
# .::::.
# .::::::::.
# :::::::::::
# ..:::::::::::'
# '::::::::::::'
# .::::::::::
# '::::::::::::::..
# ..::::::::::::.
# ``::::::::::::::::
# ::::``:::::::::' .:::.
# ::::' ':::::' .::::::::.
# .::::' :::: .:::::::'::::.
# .:::' ::::: .:::::::::' ':::::.
# .::' :::::.:::::::::' ':::::.
# .::' ::::::::::::::' ``::::.
# ...::: ::::::::::::' ``::.
# ```` ':. ':::::::::' ::::..
# '.:::::' ':'````..
# 美女保佑 永無BUG
from torch import nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 25, kernel_size=3),
nn.BatchNorm2d(25),
nn.ReLU(inplace=True)
)
self.layer2 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.layer3 = nn.Sequential(
nn.Conv2d(25, 50, kernel_size=3),
nn.BatchNorm2d(50),
nn.ReLU(inplace=True)
)
self.layer4 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Sequential(
nn.Linear(50 * 5 * 5, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 10)
)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
然後利用上述模型進行處理,其處理的方法和上一篇博文中的方法是一樣的,這裡不再贅述。
可以看到處理結果比上一次好多了: