pytorch學習:MNIST手寫數字識別程式碼
# -*- coding: utf-8 -*- """ Created on Mon Sep 3 08:38:27 2018 @author: www """ import torch from torch import nn from torchvision.datasets import MNIST import numpy as np from torch.autograd import Variable train_set = MNIST('./data', train=True, download=True) test_set = MNIST('./data', train=False, download=True) #觀察一下資料 a_data, a_label = train_set[0] print(a_data) print(a_label) #這裡的讀入的資料是 PIL 庫中的格式,我們可以非常方便地將其轉換為 numpy array a_data = np.array(a_data, dtype='float32') print(a_data) #對於神經網路,我們第一層的輸入就是 28 x 28 = 784,所以必須將得到的資料 #我們做一個變換,使用 reshape 將他們拉平成一個一維向量 def get_data(x): x = np.array(x, dtype='float32')/255 x = (x-0.5)/0.5 x = x.reshape((-1,)) x = torch.from_numpy(x) return x train_set = MNIST('./data', train=True, transform=get_data, download=True) test_set = MNIST('./data', train=False, transform=get_data, download=True) a, a_label = train_set[0] print(a.shape) print(a_label) #使用pytorch自帶的DataLoader定義一個數據迭代器 #使用這樣的資料迭代器是非常有必要的,如果資料量太大,就無法一次將他們全部讀入記憶體, #所以需要使用 python 迭代器,每次生成一個批次的資料 from torch.utils.data import DataLoader train_data = DataLoader(train_set, batch_size=64, shuffle=True) test_data = DataLoader(test_set, batch_size=128, shuffle=False) a, a_label = next(iter(train_data)) #打印出一個批次的資料 print(a.shape) print(a_label.shape) #使用Sequential定義4層神經網路 net = nn.Sequential( nn.Linear(784, 400), nn.ReLU(), nn.Linear(400, 200), nn.ReLU(), nn.Linear(200, 100), nn.ReLU(), nn.Linear(100, 10) ) #交叉熵在 pytorch 中已經內建了,交叉熵的數值穩定性更差,所以內建的函式已經幫我們解決了這個問題 #定義loss函式 criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), 1e-1) #開始訓練 losses = [] acces = [] eval_losses = [] eval_acces = [] for e in range(20): train_loss = 0 train_acc = 0 net.train() for im, label in train_data: im = Variable(im) label = Variable(label) #前向傳播 out = net(im) loss = criterion(out, label) #反向傳播 optimizer.zero_grad() loss.backward() optimizer.step() #記錄誤差 train_loss += loss.item() #計算分類的準確率 _, pred = out.max(1) num_correct = (pred==label).sum().item() acc = num_correct / im.shape[0] train_acc += acc losses.append(train_loss / len(train_data)) acces.append(train_acc / len(train_data)) #在測試集上檢驗效果 eval_loss = 0 eval_acc = 0 net.eval() for im, label in test_data: im = Variable(im) label = Variable(label) out = net(im) loss = criterion(out, label) #記錄誤差 eval_loss += loss.item() #記錄準確率 _, pred = out. max(1) num_correct = (pred==label).sum().item() acc = num_correct / im.shape[0] eval_acc += acc eval_losses.append(eval_loss / len(test_data)) eval_acces.append(eval_acc / len(test_data)) print('epoch:{}, Train Loss:{:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}' .format(e, train_loss/len(train_data), train_acc/len(train_data), eval_loss/len(test_data), eval_acc/len(test_data))) #最後一次的資料 #epoch:19, Train Loss:0.008779, Train Acc:0.997385, Eval Loss:0.072433, Eval Acc:0.982793 #畫出loss曲線和準確率曲線 import matplotlib.pyplot as plt plt.title('train loss') plt.plot(np.arange(len(losses)), losses) plt.plot(np.arange(len(acces)), acces) plt.title('train acc') plt.plot(np.arange(len(eval_losses)), eval_losses) plt.title('test loss') plt.plot(np.arange(len(eval_acces)), eval_acces) plt.title('test acc')
相關推薦
pytorch學習:MNIST手寫數字識別程式碼
# -*- coding: utf-8 -*- """ Created on Mon Sep 3 08:38:27 2018 @author: www """ import torch from torch import nn from torchvision.data
tensorflow 基礎學習五:MNIST手寫數字識別
truncate averages val flow one die correct 表示 data MNIST數據集介紹: from tensorflow.examples.tutorials.mnist import input_data # 載入MNIST數據集,
Caffe學習筆記(六):mnist手寫數字識別訓練例項
一、前言 深度學習的一個簡單例項就是mnist手寫數字識別,只要這個例子弄明白,其它的內容就可以舉一反三了。之前的內容如有遺忘,可以進行回顧。 二、準備資料 資料集可以直接從我的github下載,包括資料集和程式碼,使用如下指令:
【AI實戰】訓練第一個AI模型:MNIST手寫數字識別模型
在上篇文章中,我們已經把AI的基礎環境搭建好了(見文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基礎環境),接下來將基於tensorflow訓練第一個AI模型:MNIST手寫數字識別模型。 MNIST是一個經典的手寫數字資料集,來自美國國家
tensorflow實戰:MNIST手寫數字識別的優化2-代價函式優化,準確率98%
最簡單的tensorflow的手寫識別模型,這一節我們將會介紹其簡單的優化模型。我們會從代價函式,多層感知器,防止過擬合,以及優化器的等幾個方面來介紹優化過程。 1.代價函式的優化: 我們可以這樣將代價函式理解為真實值與預測值的差距,我們神經
Tensorflow案例4:Mnist手寫數字識別(線性神經網路)及其侷限性
學習目標 目標 應用matmul實現全連線層的計算 說明準確率的計算 應用softmax_cross_entropy_with_logits實現softamx以及交叉熵損失計算 說明全連線層在神經網路的作用 應用全連
從零開始學caffe(四):mnist手寫數字識別網路結構模型和超引數檔案的原始碼閱讀
下面為網路結構模型 %網路結構模型 name: "LeNet" #網路的名字"LeNet" layer { #定義一個層 name: "mnist" #層的名字"mnist" type:
訓練第一個AI模型:MNIST手寫數字識別模型
在上篇文章中,我們已經把AI的基礎環境搭建好了(見文章:Ubuntu + conda + tensorflow + GPU + pycharm搭建AI基礎環境),接下來將基於tensorflow訓練第一個AI模型:MNIST手寫數字識別模型。 MNIST是一個經典的手寫數字資料集,來自美國國家
TensorFlow筆記之一:MNIST手寫數字識別
本人剛剛開始接觸深度學習不久,對於tensorflow的瞭解也有限,想通過tensorflow這個框架來學習深度學習及其優化與識別。現在直接進入主題。 1.手寫識別的介紹: MNIST手寫識別在機器學習中就像c語言中Hello Wor
深度學習之mnist手寫數字識別入門
使用tensorflow框架和python,學習實現簡單的神經網路,並進行調參,程式碼如下: #! /usr/bin/python # -*- coding:utf-8 -*- """ a simple mnist classifier """ from __fut
OpenCV機器學習:SVM分類器實現MNIST手寫數字識別
0. 開發環境 最近機器學習隨著AI人工智慧的興起越來越火,博主想找一些ML的庫來練手。突然想起之前在看Opencv的doc時發現有ML的component,於是心血來潮就開始寫程式碼試試。話不多說,直接進正題。 以下我的開發環境配置: -Windows7
PyTorch基礎入門六:PyTorch搭建卷積神經網路實現MNIST手寫數字識別
1)卷積神經網路(CNN)簡介 關於什麼是卷積神經網路(CNN),請自行查閱資料進行學習。如果是初學者,這裡推薦一下臺灣的李巨集毅的深度學習課程。連結就不給了,這些資料網站上隨處可見。 值得一提的是,CNN雖然在影象處理的領域具有不可阻擋的勢頭,但是它絕對不僅僅只能用來影
Tensorflow深度學習之七:再談mnist手寫數字識別程式
之前學習的第一個深度學習的程式就是mnist手寫字型的識別,那個時候對於很多概念不是很理解,現在回過頭再看當時的程式碼,理解了很多,現將加了註釋的程式碼貼上,與大家分享。(本人還是在學習Tensorflow的初始階段,如果有什麼地方理解有誤,還請大家不吝指出。)
Tensorflow之MNIST手寫數字識別:分類問題(1)
一、MNIST資料集讀取 one hot 獨熱編碼獨熱編碼是一種稀疏向量,其中:一個向量設為1,其他元素均設為0.獨熱編碼常用於表示擁有有限個可能值的字串或識別符號優點: 1、將離散特徵的取值擴充套件到了歐式空間,離散特徵的某個取值就對應歐式空間的某個點 2、機器學習演算法中,
Tensorflow之MNIST手寫數字識別:分類問題(2)
整體程式碼: #資料讀取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnis
Tensorflow案例5:CNN演算法-Mnist手寫數字識別
學習目標 目標 應用tf.nn.conv2d實現卷積計算 應用tf.nn.relu實現啟用函式計算 應用tf.nn.max_pool實現池化層的計算 應用卷積神經網路實現影象分類識別 應用
深度學習之PyTorch —— CNN實現MNIST手寫數字分類
# 運用CNN分析MNIST手寫數字分類 import torch import numpy as np from torch.utils.data import DataLoader from torchvision.datasets import mnist fro
TensorFlow實現機器學習的“Hello World”--Mnist手寫數字識別
TensorFlow實現機器學習的“Hello World” 上一篇部落格我們已經說了TensorFlow大概怎麼使用,這次來說說機器學習中特別經典的案例,也相當於是機器學習的“Hello World”,他就是Mnist手寫數字識別,也就是通過訓練機器讓他能看
機器學習-實戰-入門-MNIST手寫數字識別
作者:橘子派 宣告:版權所有,轉載請註明出處,謝謝。 實驗環境: Windows10 Sublime Anaconda 1.6.0 Python3.6 程式碼功能包括: 一.ubyte資料集轉
Pytorch 快速入門(六)MNIST 手寫數字識別之CNN
接觸過深度學習的朋友對MNIST資料集肯定不陌生。基本上算是玩神經網路裡的“hello,world!”本節基於MNIST資料集,實現CNN學習過程。下面是一個 CNN 最後一層的學習過程, 我們先視覺化看看:MNIST手寫資料 首先是資料集下載,為了看看資料集裡究竟是長什麼樣