1. 程式人生 > >基於PyTorch的CNN卷積神經網路識別MNIST手寫數字

基於PyTorch的CNN卷積神經網路識別MNIST手寫數字

本篇部落格主要介紹基於PyTorch深度學習框架來實現MNIST經典的手寫數字,運用CNN卷積神經網路。

MNIST資料集來自美國國家標準與技術研究所,其中訓練資料有60000張,測試資料有10000張,每張圖片的大小是28*28畫素

我們可以基於PyTorch直接下載該資料集。該識別程式先使用一層卷積層(卷積核數量16,卷積核大小5*5,步長為1,允許邊緣擴充),緊接著啟用層使用ReLU函式,之後緊跟著一個max pooling層,大小是2*2. 之後再設定同樣的卷積層(卷積核為2種)、啟用層、降取樣層,最後跟一個全連線層,輸出為10個神經元表示有十類,分別是0,1,2,3……9

在各層網路中資料的規模變化如下:

初始情況:1*28*28             (表示只有一個顏色通道,一張圖片畫素大小是28*28)

第一層卷積之後:16*28*28  (使用了16種卷積核,步長為1,兩邊各擴充1格,大小不變28*28)

第一層max pooling之後: 16*14*14 (使用2*2的大小進行降取樣,圖的大小迅速縮小一半)

第二層卷積之後:32*14*14

第二層max pooling之後:32*7*7

全連線層:10  (共有十個類別)

程式碼展示:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

torch.manual_seed(1)   #結果可復現

#設定超引數
EPOCH = 5
BATCH_SIZE= 50
LR = 0.01
DOWNLOAD_MNIST = True #是否下載資料

train_data = torchvision.datasets.MNIST(
    root = './mnist/', #儲存的位置
    train = True,   #表示是訓練資料
    transform=torchvision.transforms.ToTensor(),
    download = DOWNLOAD_MNIST,
)

test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)

train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)

test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)[:2000]/255
test_y = test_data.test_labels[:2000]

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,    #高度是1,就一個通道
                out_channels=16,  #卷積核數量
                kernel_size=5,    #卷積核大小
                stride=1,         #設定步長為1
                padding=2,        #邊緣擴充兩格
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.out = nn.Linear(32*7*7,10)

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0),-1)  #展平多維的卷積圖
        output = self.out(x)
        return output

cnn = CNN()
print(cnn)

optimizer = torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step,(x,y) in enumerate(train_loader):
        b_x = Variable(x)
        b_y = Variable(y)

        output = cnn(b_x)
        loss = loss_func(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    test_output=cnn(test_x)
    pred_y = torch.max(test_output,1)[1].data.squeeze()
    accuracy = sum(pred_y==test_y) / float(test_y.size(0))
    print('Epoch: ',epoch, '| train loss: %.4f' %loss.data[0],'| test accuracy: %.2f' %accuracy)

test_output = cnn(test_x[:10])
pre_y = torch.max(test_output,1)[1].data.numpy().squeeze()
print(pre_y,'prediction number')
print(test_y[:10].numpy(),'real number')

輸出CNN卷積網路的形態:


輸出前2000個測試資料的準確率:


輸出前10個測試資料的預測資料和真實值: