1. 程式人生 > >Pytorch 快速入門(六)MNIST 手寫數字識別之CNN

Pytorch 快速入門(六)MNIST 手寫數字識別之CNN

接觸過深度學習的朋友對MNIST資料集肯定不陌生。基本上算是玩神經網路裡的“hello,world!”

本節基於MNIST資料集,實現CNN學習過程。

下面是一個 CNN 最後一層的學習過程, 我們先視覺化看看:

CNN 卷積神經網路

MNIST手寫資料 

首先是資料集下載,為了看看資料集裡究竟是長什麼樣子的,我也作了展示:

# library
# standard library
import os

# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001              # learning rate
DOWNLOAD_MNIST = False


# Mnist digits dataset
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
    # not mnist dir or mnist is empyt dir
    DOWNLOAD_MNIST = True

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,
)

# plot one example
print(train_data.train_data.size())                 # (60000, 28, 28)
print(train_data.train_labels.size())               # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show()

結果如下:

首先看下載後的資料集目錄


展示example


黑色的地方的值都是0, 白色的地方值大於0.

同樣, 我們除了訓練資料, 還給一些測試資料, 測試看看它有沒有訓練好

CNN模型 

和以前一樣, 我們用一個 class 來建立 CNN 模型. 這個 CNN 整體流程是 卷積(Conv2d) -> 激勵函式(ReLU) -> 池化, 向下取樣 (MaxPooling) -> 再來一遍 -> 展平多維的卷積成的特徵圖 -> 接入全連線層 (Linear) -> 輸出

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# convert test data into Variable, pick 2000 samples to speed up testing
# # 為了節約時間, 我們測試時只測試前2000個
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.   # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # input height
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x    # return x for visualization


cnn = CNN()
print(cnn)  # net architecture

結構列印如下:

訓練 

下面我們開始訓練, 將 x y 都用 Variable 包起來, 然後放入 cnn 中計算 output, 最後再計算誤差. 


optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted

# following function (plot_with_labels) is for visualization, can be ignored if not interested
from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
    plt.cla()
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
    plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()
# training and testing
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):   # gives batch data, normalize x when iterate train_loader
        b_x = Variable(x)   # batch x
        b_y = Variable(y)   # batch y

        output = cnn(b_x)[0]               # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            test_output, last_layer = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = sum(pred_y == test_y) / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
            if HAS_SK:
                # Visualization of trained flatten layer (T-SNE)
                tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
                plot_only = 500
                low_dim_embs = tsne.fit_transform(last_layer.data.numpy()[:plot_only, :])
                labels = test_y.numpy()[:plot_only]
                plot_with_labels(low_dim_embs, labels)
plt.ioff()

訓練過程展示:




最後我們再來取10個數據, 看看預測的值到底對不對:

預測結果


視覺化訓練

     因為視覺化可以幫助理解, 所以還是有必要提一下. 視覺化的程式碼主要是用 matplotlib 和 sklearn 來完成的, 因為其中我們用到了 T-SNE 的降維手段, 將高維的 CNN 最後一層輸出結果視覺化, 也就是 CNN forward 程式碼中的 x = x.view(x.size(0), -1) 這一個結果.

相關推薦

Pytorch 快速入門MNIST 數字識別CNN

接觸過深度學習的朋友對MNIST資料集肯定不陌生。基本上算是玩神經網路裡的“hello,world!”本節基於MNIST資料集,實現CNN學習過程。下面是一個 CNN 最後一層的學習過程, 我們先視覺化看看:MNIST手寫資料 首先是資料集下載,為了看看資料集裡究竟是長什麼樣

TensorFlow程式碼實現[MNIST數字識別]

最簡單的神經網路結構: 資料來源準備:資料在之前的文章中分析過了 在這裡我們就構造一層神經網路: 前提準備: 引數: train images:因為圖片是28*28的個數,換算成一維陣列就是784,因此我們定義x = tf.placeholder(tf

Mnist數字識別CNN實現

Mnist手寫數字識別之CNN實現 最近有點閒,想整一下機器學習,本以為自己程式設計還不錯,想想機器學習也不難,結果被自己啪啪啪的打臉,還疼的不行。 廢話不多說,開始搞事情。 本部落格的主要內容是:通過TF一步一步用卷積神經網路(CNN)實現手寫Mnist數字識別 如果你

深度學習PyTorch實戰3——實戰數字識別

  上一節,我們已經學會了基於PyTorch深度學習框架高效,快捷的搭建一個神經網路,並對模型進行訓練和對引數進行優化的方法,接下來讓我們牛刀小試,基於PyTorch框架使用神經網路來解決一個關於手寫數字識別的計算機視覺問題,評價我們搭建的模型的標準是它是否能準確的對手寫數字圖片進行識別。   

Tensorflow深度學習筆記--BPNN數字識別視覺化

資料集:MNIST 啟用函式:Relu 損失函式:交叉熵 Optimizer:AdamOptimizer 視覺化工具:tensorboad 迭代21epoch,accuracy結果如下: Iter 16,Testing Accuracy:

機器學習--k-近鄰演算法kNN實現數字識別

這裡的手寫數字以0,1的形式儲存在文字檔案中,大小是32x32.目錄trainingDigits有1934個樣本。0-9每個數字大約有200個樣本,命名規則如下: 下劃線前的數字代表是樣本0-9的

基於k近鄰KNN數字識別

作者:faaronzheng 轉載請註明出處! 最近再看Machine Learning in Action. k近鄰演算法這一章節提供了不少例子,本著Talk is cheap的原則,我們用手寫數字識別來實際測試一下。 簡單的介紹一下k近鄰演算法(KNN):給定測試樣本

Caffe學習筆記():mnist數字識別訓練例項

一、前言    深度學習的一個簡單例項就是mnist手寫數字識別,只要這個例子弄明白,其它的內容就可以舉一反三了。之前的內容如有遺忘,可以進行回顧。 二、準備資料    資料集可以直接從我的github下載,包括資料集和程式碼,使用如下指令:

Spring Boot快速入門:thymeleaf

return 之前 err static 默認 示例 圖片資源 css 官網 原文地址:https://lierabbit.cn/articles/8 靜態資源 在我們開發Web應用的時候,需要引用大量的js、css、圖片等靜態資源。 Spring Boot的默認位置是re

區塊鏈快速入門——區塊鏈密碼學與安全相關技術

ocs 國家 發出 .com 少包 signature payment 有效期 保護 區塊鏈快速入門(六)——區塊鏈密碼學與安全相關技術 一、區塊鏈密碼安全技術簡介 區塊鏈和分布式賬本中大量使用了密碼學和安全技術的最新成果,特別是身份認證和隱私保護相關技術。區塊鏈使用了包括

06 React快速入門——使用css樣式

      在react中元件使用樣式主要有兩種方式:內聯式和外聯式。下面分別介紹兩種方式:       一、內聯式       使用內聯樣式,和html中差不多,不多在此處需要注意的是,通過sty

java多執行緒快速入門

多執行緒應用例項(批量傳送簡訊) 1、建立實體類 package com.cppdy; public class UserEntity { private int id; private String name; public int g

java多線程快速入門

static code msg close play int vat one bsp 多線程應用實例(批量發送短信) 1、創建實體類 package com.cppdy; public class UserEntity { private int i

基於tensorflow的MNIST數字識別--入門

一、本文的意義       因為谷歌官方其實已經寫了MNIST入門和深入兩篇教程了,那我寫這些文章又是為什麼呢,只是抄襲?那倒並不是,更準確的說應該是筆記吧,然後用更通俗的語言來解釋,並且補充

PyTorch基礎入門PyTorch搭建卷積神經網路實現MNIST數字識別

1)卷積神經網路(CNN)簡介 關於什麼是卷積神經網路(CNN),請自行查閱資料進行學習。如果是初學者,這裡推薦一下臺灣的李巨集毅的深度學習課程。連結就不給了,這些資料網站上隨處可見。 值得一提的是,CNN雖然在影象處理的領域具有不可阻擋的勢頭,但是它絕對不僅僅只能用來影

Pytorch 快速入門載入預訓練模型初始化網路引數

在預訓練網路的基礎上,修改部分層得到自己的網路,通常我們需要解決的問題包括: 1. 從預訓練的模型載入引數 2. 對新網路兩部分設定不同的學習率,主要訓練自己新增的層 PyTorch提供的預訓練模型PyTorch定義了幾個常用模型,並且提供了預訓練版本:AlexNet: Al

mnist數字識別——深度學習入門專案tensorflow+keras+Sequential模型

前言 今天記錄一下深度學習的另外一個入門專案——《mnist資料集手寫數字識別》,這是一個入門必備的學習案例,主要使用了tensorflow下的keras網路結構的Sequential模型,常用層的Dense全連線層、Activation啟用層和Reshape層。還有其他方法訓練手寫數

TensorflowMNIST數字識別:分類問題1

一、MNIST資料集讀取 one hot 獨熱編碼獨熱編碼是一種稀疏向量,其中:一個向量設為1,其他元素均設為0.獨熱編碼常用於表示擁有有限個可能值的字串或識別符號優點:   1、將離散特徵的取值擴充套件到了歐式空間,離散特徵的某個取值就對應歐式空間的某個點    2、機器學習演算法中,

TensorflowMNIST數字識別:分類問題2

整體程式碼: #資料讀取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnis

TensorFlow學習筆記1—— MNIST識別

1、初步學習 資料處理 xs:60000張圖片,28*28大小,將所有畫素點按一列排列,資料集變為了[60000, 784]的二維矩陣。 ys:60000張圖片,每個圖片有一個標籤標識圖片中數字,採用one-hot向量,資料集變為[60000, 10]的二維矩陣。 softm