1. 程式人生 > >基於PyTorch的LSTM長短時記憶網路實現MNIST手寫數字

基於PyTorch的LSTM長短時記憶網路實現MNIST手寫數字

本篇部落格主要介紹在PyTorch框架下,基於LSTM實現手寫數字的識別。

在介紹LSTM長短時記憶網路之前,我先介紹一下RNN(recurrent neural network)迴圈神經網路.

RNN是一種用來處理序列資料的神經網路,序列資料包括我們說話的語音、一段文字等等。它的出現是為了讓網路自己有記憶能力,每個網路模組把資訊傳給下一個模組,它的網路結構如下:


對於輸入的一段序列資料(X1,X2,X3,X4……)給出預測的結果(Y1,Y2,Y3,Y4……),如果是一個文字情感分類問題,那麼將前幾個Y值抹去,剩下最後一個Y,即是一段文字的預測的情感分類結果。

當進行梯度下降法更新引數時,RNN會出現梯度消失或者梯度下降的問題。當每層的權值W小於1,那麼誤差傳到最開始,結果接近於0,梯度消失;當每層的權值W大於1,那麼誤差傳到第一層,結果會變得無窮大,梯度爆炸。

為了避免這個問題,我們引出了LSTM長短時記憶網路,該網路主要用來延緩記憶衰退。

LSTM的網路結構如下:


它將前一個時刻的輸出Y也傳給了下一個時刻,並且中間單元加了一個判斷資訊是否有用的功能,這樣的話,就可以將讓網路自己選擇要記憶有用的資訊,遺忘無用的資訊。主要通過三個門來實現---輸入門、遺忘門和輸出門。

在識別MNIST手寫數字時,我們可以把28*28畫素的一幅圖片,按行輸入,每一行對應一個時刻,這樣有28個X輸入,有1個Y輸出,將每行的畫素值類比成一個序列資料。

以下程式碼測試了在測試資料的準確率,以及取前10個測試資料的結果。

import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torchvision

torch.manual_seed(1)

EPOCH = 1
BATCH_SIZE = 64
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01
DOWNLOAD_MNIST = False

train_data = dsets.MNIST(
    root = './mnist',
    train = True,
    transform = torchvision.transforms.ToTensor(),
    download = DOWNLOAD_MNIST,
)

test_data = torchvision.datasets.MNIST(root='./mnist',train=False)

train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)

test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)/255
test_y = test_data.test_labels

class RNN(nn.Module):
    def __init__(self):
        super(RNN,self).__init__()

        self.rnn = nn.LSTM(
            input_size=28,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.out = nn.Linear(64,10)

    def forward(self,x):
        r_out, (h_n, h_c) = self.rnn(x, None)

        out = self.out(r_out[:,-1,:])
        return out

rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step,(x,y) in enumerate(train_loader):
        b_x = Variable(x.view(-1,28,28))
        b_y = Variable(y)

        output = rnn(b_x)
        loss = loss_func(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step%50 == 0:
            test_output = rnn(test_x.view(-1,28,28))
            pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
            accuracy = sum(pred_y == test_y)/float(test_y.size(0))
            print('Epoch: ',epoch, '| train loss:%.4f' %loss.data[0],'| test accuracy:%.2f' %accuracy)

test_output = rnn(test_x[:10].view(-1,28,28))
pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
print(pred_y,'prediction number')
print(test_y[:10].numpy(),'real number')

執行結果:


相關推薦

基於PyTorch的LSTM長短記憶網路實現MNIST數字

本篇部落格主要介紹在PyTorch框架下,基於LSTM實現手寫數字的識別。在介紹LSTM長短時記憶網路之前,我先介紹一下RNN(recurrent neural network)迴圈神經網路.RNN是一種用來處理序列資料的神經網路,序列資料包括我們說話的語音、一段文字等等。它

神經網路實現Mnist數字識別筆記

目錄 1.Mnist手寫數字識別介紹         Mnist手寫數字識別是Kaggle上一個很經典的機器學習資料集,裡邊包括55000張訓練資料和10000張圖片的測試資料,每張圖片大小為28*28畫素的單通圖片。該任務為通過機器學習來識別圖片中的

PyTorch基礎入門六:PyTorch搭建卷積神經網路實現MNIST數字識別

1)卷積神經網路(CNN)簡介 關於什麼是卷積神經網路(CNN),請自行查閱資料進行學習。如果是初學者,這裡推薦一下臺灣的李巨集毅的深度學習課程。連結就不給了,這些資料網站上隨處可見。 值得一提的是,CNN雖然在影象處理的領域具有不可阻擋的勢頭,但是它絕對不僅僅只能用來影

深度學習入門——利用卷積神經網路實現MNIST數字識別

MNIST(Modified National Institute of Standards and Technology)資料庫是一個大型手寫數字資料庫,通常用於訓練各種影象處理系統。該資料庫還廣泛用於機器學習領域的培訓和測試。它是通過重新打亂來自NIST原始資料集的樣本而

深度學習:tensorflow入門:卷積神經網路實現MNIST字型識別

程式碼中./data/mnist/input_data/為真實MNIST資料集的路徑 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 定義

Tensorflow卷積神經網路實現MNIST資料集識別

模型建的不好,最終只有85%左右的準確率,後面繼續改進吧   #卷積神經網路API     卷積層:tf.nn.conv2d(input,    #輸入張量,具有[batch, height, width, chann

Tensorflow 實戰Google深度學習框架——學習筆記(六)LeNet-5網路實現MNIST數字集識別

使用LeNet-5模型實現MNIST手寫數字識別,其神經網路架構如下: 一、詳細介紹LeNet-5模型每一層的結構 第一層,卷積層 這一層輸入原始的影象畫素,接受的輸入層大小為32*32*1,第一個卷積層過濾器尺寸為5*5,共6個,不使用全0填

全連線神經網路實現MNIST資料集識別

有目錄,內容大部分從官方教程copy,黑體加粗為對官方教程的補充 TensorFlow,pytorch,cuda,cudnn,anaconda安裝 版本對應關係 Version Python version Compiler Build tools cu

tensorflow-GPU 一步步搭建網路實現MNIST識別

1. MNIST資料集的匯入 這裡介紹一下MNIST,MNIST是在機器學習領域中的一個經典問題。該問題解決的是把28x28畫素的灰度手寫數字圖片識別為相應的數字,其中數字的範圍從0到9. 首先我們要匯入MNIST資料集,這裡需要用到一個input_data.py檔案,在

基於PyTorch的CNN卷積神經網路識別MNIST數字

本篇部落格主要介紹基於PyTorch深度學習框架來實現MNIST經典的手寫數字,運用CNN卷積神經網路。MNIST資料集來自美國國家標準與技術研究所,其中訓練資料有60000張,測試資料有10000張,每張圖片的大小是28*28畫素我們可以基於PyTorch直接下載該資料集。

使用LeNet-5實現mnist數字分類識別 TensorFlow

TensorFlow的學習材料很多,但很少有講得特別詳細,讓小白一看就懂的。我自己總結了cnn實現mnist分類識別的方法,希望能給TensorFlow初學者一些幫助,實測在python3下可以執行。 # -*- coding: utf-8 -*- # 使用LeNet-5實

深度學習之PyTorch —— CNN實現MNIST數字分類

​# 運用CNN分析MNIST手寫數字分類 import torch import numpy as np from torch.utils.data import DataLoader from torchvision.datasets import mnist fro

CNN實現MNIST數字識別

關鍵詞:CNN、TensorFlow、卷積、池化、特徵圖 一. 前言 本文用TensorFlow實現了CNN(卷積神經網路)的經典結構LeNet-5, 具體CNN的LeNet-5模型原理見《深度學習(四)卷積神經網路入門學習(1)》,講得還是比較清楚的。

OpenCV機器學習:SVM分類器實現MNIST數字識別

0. 開發環境 最近機器學習隨著AI人工智慧的興起越來越火,博主想找一些ML的庫來練手。突然想起之前在看Opencv的doc時發現有ML的component,於是心血來潮就開始寫程式碼試試。話不多說,直接進正題。 以下我的開發環境配置: -Windows7

深度學習系列——AlxeNet實現MNIST數字體識別

    本文實現AlexNet,用於識別MNIST手寫數字體。所有程式碼的框架基於tensorflow。看了幾篇論文的原始碼之後,覺得tensorflow 確實很難,學習程式設計還是靠實踐。這篇部落格留著給自己以及學習深度學習道路上的小夥伴們一些參考吧,希望能對大家有所幫助!

matlab神經網路實現識別數字

實驗說明 一直想自己寫一個神經網路來實現手寫數字的識別,而不是套用別人的框架。恰巧前幾天,有幸從同學那拿到5000張已經貼好標籤的手寫數字圖片,於是我就嘗試用matlab寫一個網路。 實驗資料:5000張手寫數字圖片(.jpg),圖片命名為1.jpg,2.

使用LSTM實現mnist數字分類識別 TensorFlow

RNN做影象識別原理:MNIST資料集中一張圖片資料包含28*28的畫素點。RNN是將一張圖片資料的一行作為一個向量總體輸入一個X中。也就是說,RNN有28個輸入X,一個輸入X有28個畫素點。 輸出最後一個結果做為預測值。   TensorFlow入門學習程式碼: # -

python神經網路實現識別數字

實驗說明 一直想自己寫一個神經網路來實現手寫數字的識別,而不是套用別人的框架。恰巧前幾天,有幸從同學那拿到5000張已經貼好標籤的手寫數字圖片,於是我就嘗試用matlab寫一個網路。 實驗資料:5000張手寫數字圖片(.jpg),圖片命名為

【MATLAB】BP神經網路識別MNIST數字

一.Summary本文運用BP神經網路對MNIST手寫數字字元進行識別。BP神經網路是通過樣本以及期望輸出不斷調整權值以達到訓練的目的的演算法。本文采用三層BP神經網路對MNIST字元進行訓練,然後對提供的資料進行測試。二.Content1. 分析BP神經網路的基本原理:通過

運用tensorflow全連線神經網路進行MNIST數字影象識別

本文記錄tensorflow搭建簡單神經網路,並進行模組化處理,目的在於總結並提取簡單神經網路搭建的基本思想和方法,提煉核心結構和元素,從而能夠移植到日後深入學習中去。 1 模組提煉 1.1 template_forward.py