1. 程式人生 > >pytorch學習:MNIST手寫數字識別程式碼

pytorch學習:MNIST手寫數字識別程式碼

# -*- coding: utf-8 -*-
"""
Created on Mon Sep  3 08:38:27 2018

@author: www
"""

import torch
from torch import nn
from torchvision.datasets import MNIST
import numpy as np
from torch.autograd import Variable

train_set = MNIST('./data', train=True, download=True)
test_set = MNIST('./data', train=False, download=True)

#觀察一下資料
a_data, a_label = train_set[0]
print(a_data)
print(a_label)

#這裡的讀入的資料是 PIL 庫中的格式,我們可以非常方便地將其轉換為 numpy array
a_data = np.array(a_data, dtype='float32')
print(a_data)

#對於神經網路,我們第一層的輸入就是 28 x 28 = 784,所以必須將得到的資料
#我們做一個變換,使用 reshape 將他們拉平成一個一維向量
def get_data(x):
     x = np.array(x, dtype='float32')/255
     x = (x-0.5)/0.5
     x = x.reshape((-1,))
     x = torch.from_numpy(x)
     return x
train_set = MNIST('./data', train=True, transform=get_data, download=True)
test_set = MNIST('./data', train=False, transform=get_data, download=True)

a, a_label = train_set[0]
print(a.shape)
print(a_label)

#使用pytorch自帶的DataLoader定義一個數據迭代器
#使用這樣的資料迭代器是非常有必要的,如果資料量太大,就無法一次將他們全部讀入記憶體,
#所以需要使用 python 迭代器,每次生成一個批次的資料
from torch.utils.data import DataLoader
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

a, a_label = next(iter(train_data))

#打印出一個批次的資料
print(a.shape)
print(a_label.shape)

#使用Sequential定義4層神經網路
net = nn.Sequential(
     nn.Linear(784, 400),
     nn.ReLU(),
     nn.Linear(400, 200),
     nn.ReLU(),
     nn.Linear(200, 100),
     nn.ReLU(),
     nn.Linear(100, 10)
)

#交叉熵在 pytorch 中已經內建了,交叉熵的數值穩定性更差,所以內建的函式已經幫我們解決了這個問題
#定義loss函式
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), 1e-1)

#開始訓練
losses = []
acces = []
eval_losses = []
eval_acces = []

for e in range(20):
     train_loss = 0
     train_acc = 0
     net.train()
     for im, label in train_data:
          im = Variable(im)
          label = Variable(label)
          #前向傳播
          out = net(im)
          loss = criterion(out, label)
          #反向傳播
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          #記錄誤差
          train_loss += loss.item()
          #計算分類的準確率
          _, pred = out.max(1)
          num_correct = (pred==label).sum().item()
          acc = num_correct / im.shape[0]
          train_acc += acc
     losses.append(train_loss / len(train_data))
     acces.append(train_acc / len(train_data))
     #在測試集上檢驗效果
     eval_loss = 0
     eval_acc = 0
     net.eval() 
     for im, label in test_data:
          im = Variable(im)
          label = Variable(label)
          out = net(im)
          loss = criterion(out, label)
          #記錄誤差
          eval_loss += loss.item()
          #記錄準確率
          _, pred = out. max(1)
          num_correct = (pred==label).sum().item()
          acc = num_correct / im.shape[0]
          eval_acc += acc
     eval_losses.append(eval_loss / len(test_data))
     eval_acces.append(eval_acc / len(test_data))
     print('epoch:{}, Train Loss:{:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}'
     .format(e, train_loss/len(train_data), train_acc/len(train_data), 
             eval_loss/len(test_data), eval_acc/len(test_data)))

#最後一次的資料
#epoch:19, Train Loss:0.008779, Train Acc:0.997385, Eval Loss:0.072433, Eval Acc:0.982793

#畫出loss曲線和準確率曲線
import matplotlib.pyplot as plt
plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)

plt.plot(np.arange(len(acces)), acces)
plt.title('train acc')

plt.plot(np.arange(len(eval_losses)), eval_losses)
plt.title('test loss')

plt.plot(np.arange(len(eval_acces)), eval_acces)
plt.title('test acc')












     

相關推薦

pytorch學習MNIST數字識別程式碼

# -*- coding: utf-8 -*- """ Created on Mon Sep 3 08:38:27 2018 @author: www """ import torch from torch import nn from torchvision.data

tensorflow 基礎學習MNIST數字識別

truncate averages val flow one die correct 表示 data MNIST數據集介紹: from tensorflow.examples.tutorials.mnist import input_data # 載入MNIST數據集,

Caffe學習筆記(六)mnist數字識別訓練例項

一、前言    深度學習的一個簡單例項就是mnist手寫數字識別,只要這個例子弄明白,其它的內容就可以舉一反三了。之前的內容如有遺忘,可以進行回顧。 二、準備資料    資料集可以直接從我的github下載,包括資料集和程式碼,使用如下指令:

【AI實戰】訓練第一個AI模型MNIST數字識別模型

在上篇文章中,我們已經把AI的基礎環境搭建好了(見文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基礎環境),接下來將基於tensorflow訓練第一個AI模型:MNIST手寫數字識別模型。 MNIST是一個經典的手寫數字資料集,來自美國國家

tensorflow實戰MNIST數字識別的優化2-代價函式優化,準確率98%

最簡單的tensorflow的手寫識別模型,這一節我們將會介紹其簡單的優化模型。我們會從代價函式,多層感知器,防止過擬合,以及優化器的等幾個方面來介紹優化過程。    1.代價函式的優化:             我們可以這樣將代價函式理解為真實值與預測值的差距,我們神經

Tensorflow案例4Mnist數字識別(線性神經網路)及其侷限性

學習目標 目標 應用matmul實現全連線層的計算 說明準確率的計算 應用softmax_cross_entropy_with_logits實現softamx以及交叉熵損失計算 說明全連線層在神經網路的作用 應用全連

從零開始學caffe(四)mnist數字識別網路結構模型和超引數檔案的原始碼閱讀

下面為網路結構模型 %網路結構模型 name: "LeNet" #網路的名字"LeNet" layer { #定義一個層 name: "mnist" #層的名字"mnist" type:

訓練第一個AI模型MNIST數字識別模型

在上篇文章中,我們已經把AI的基礎環境搭建好了(見文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基礎環境),接下來將基於tensorflow訓練第一個AI模型:MNIST手寫數字識別模型。 MNIST是一個經典的手寫數字資料集,來自美國國家

TensorFlow筆記之一MNIST數字識別

     本人剛剛開始接觸深度學習不久,對於tensorflow的瞭解也有限,想通過tensorflow這個框架來學習深度學習及其優化與識別。現在直接進入主題。     1.手寫識別的介紹:            MNIST手寫識別在機器學習中就像c語言中Hello Wor

深度學習mnist數字識別入門

使用tensorflow框架和python,學習實現簡單的神經網路,並進行調參,程式碼如下:   #! /usr/bin/python # -*- coding:utf-8 -*- """ a simple mnist classifier """ from __fut

OpenCV機器學習SVM分類器實現MNIST數字識別

0. 開發環境 最近機器學習隨著AI人工智慧的興起越來越火,博主想找一些ML的庫來練手。突然想起之前在看Opencv的doc時發現有ML的component,於是心血來潮就開始寫程式碼試試。話不多說,直接進正題。 以下我的開發環境配置: -Windows7

PyTorch基礎入門六PyTorch搭建卷積神經網路實現MNIST數字識別

1)卷積神經網路(CNN)簡介 關於什麼是卷積神經網路(CNN),請自行查閱資料進行學習。如果是初學者,這裡推薦一下臺灣的李巨集毅的深度學習課程。連結就不給了,這些資料網站上隨處可見。 值得一提的是,CNN雖然在影象處理的領域具有不可阻擋的勢頭,但是它絕對不僅僅只能用來影

Tensorflow深度學習之七再談mnist數字識別程式

之前學習的第一個深度學習的程式就是mnist手寫字型的識別,那個時候對於很多概念不是很理解,現在回過頭再看當時的程式碼,理解了很多,現將加了註釋的程式碼貼上,與大家分享。(本人還是在學習Tensorflow的初始階段,如果有什麼地方理解有誤,還請大家不吝指出。)

Tensorflow之MNIST數字識別分類問題(1)

一、MNIST資料集讀取 one hot 獨熱編碼獨熱編碼是一種稀疏向量,其中:一個向量設為1,其他元素均設為0.獨熱編碼常用於表示擁有有限個可能值的字串或識別符號優點:   1、將離散特徵的取值擴充套件到了歐式空間,離散特徵的某個取值就對應歐式空間的某個點    2、機器學習演算法中,

Tensorflow之MNIST數字識別分類問題(2)

整體程式碼: #資料讀取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnis

Tensorflow案例5CNN演算法-Mnist數字識別

學習目標 目標 應用tf.nn.conv2d實現卷積計算 應用tf.nn.relu實現啟用函式計算 應用tf.nn.max_pool實現池化層的計算 應用卷積神經網路實現影象分類識別 應用

深度學習PyTorch —— CNN實現MNIST數字分類

​# 運用CNN分析MNIST手寫數字分類 import torch import numpy as np from torch.utils.data import DataLoader from torchvision.datasets import mnist fro

TensorFlow實現機器學習的“Hello World”--Mnist數字識別

TensorFlow實現機器學習的“Hello World” 上一篇部落格我們已經說了TensorFlow大概怎麼使用,這次來說說機器學習中特別經典的案例,也相當於是機器學習的“Hello World”,他就是Mnist手寫數字識別,也就是通過訓練機器讓他能看

機器學習-實戰-入門-MNIST數字識別

作者:橘子派 宣告:版權所有,轉載請註明出處,謝謝。 實驗環境: Windows10 Sublime Anaconda 1.6.0 Python3.6 程式碼功能包括: 一.ubyte資料集轉

Pytorch 快速入門(六)MNIST 數字識別之CNN

接觸過深度學習的朋友對MNIST資料集肯定不陌生。基本上算是玩神經網路裡的“hello,world!”本節基於MNIST資料集,實現CNN學習過程。下面是一個 CNN 最後一層的學習過程, 我們先視覺化看看:MNIST手寫資料 首先是資料集下載,為了看看資料集裡究竟是長什麼樣