1. 程式人生 > >MNIST 資料集讀取和視覺化

MNIST 資料集讀取和視覺化

MNIST 資料集已經是一個被”嚼爛”了的資料集, 很多教程都會對它”下手”, 幾乎成為一個 “典範”. 不過有些人可能對它還不是很瞭解, 下面來介紹一下.

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個標籤)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標籤)

MNIST 資料集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字資料.

不妨新建一個資料夾 – mnist, 將資料集下載到 mnist 以後, 解壓即可:

dataset

圖片是以位元組的形式進行儲存, 我們需要把它們讀取到 NumPy array 中, 以便訓練和測試演算法.

import os
import struct
import numpy as np

def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
% kind) with open(labels_path, 'rb') as lbpath: magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8) with open(images_path, 'rb') as imgpath: magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16)) images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) return images, labels

load_mnist 函式返回兩個陣列, 第一個是一個 n x m 維的 NumPy array(images), 這裡的 n 是樣本數(行數), m 是特徵數(列數). 訓練資料集包含 60,000 個樣本, 測試資料集包含 10,000 樣本. 在 MNIST 資料集中的每張圖片由 28 x 28 個畫素點構成, 每個畫素點用一個灰度值表示. 在這裡, 我們將 28 x 28 的畫素展開為一個一維的行向量, 這些行向量就是圖片數組裡的行(每行 784 個值, 或者說每行就是代表了一張圖片). load_mnist 函式返回的第二個陣列(labels) 包含了相應的目標變數, 也就是手寫數字的類標籤(整數 0-9).

第一次見的話, 可能會覺得我們讀取圖片的方式有點奇怪:

magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
  • 為了理解這兩行程式碼, 我們先來看一下 MNIST 網站上對資料集的介紹:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.
  • 通過使用上面兩行程式碼, 我們首先讀入 magic number, 它是一個檔案協議的描述, 也是在我們呼叫 fromfile 方法將位元組讀入 NumPy array 之前在檔案緩衝中的 item 數(n). 作為引數值傳入 struct.unpack 的 >II 有兩個部分:
  • >: 這是指大端(用來定義位元組是如何儲存的); 如果你還不知道什麼是大端和小端, Endianness 是一個非常好的解釋. (關於大小端, 更多內容可見<<深入理解計算機系統 – 2.1 節資訊儲存>>)
  • I: 這是指一個無符號整數.

通過執行下面的程式碼, 我們將會從剛剛解壓 MNIST 資料集後的 mnist 目錄下載入 60,000 個訓練樣本和 10,000 個測試樣本.

為了瞭解 MNIST 中的圖片看起來到底是個啥, 讓我們來對它們進行視覺化處理. 從 feature matrix 中將 784-畫素值 的向量 reshape 為之前的 28*28 的形狀, 然後通過 matplotlib 的 imshow 函式進行繪製:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(10):
    img = X_train[y_train == i][0].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

我們現在應該可以看到一個 2*5 的圖片, 裡面分別是 0-9 單個數字的圖片.

0-9

此外, 我們還可以繪製某一數字的多個樣本圖片, 來看一下這些手寫樣本到底有多不同:

fig, ax = plt.subplots(
    nrows=5,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(25):
    img = X_train[y_train == 7][i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

執行上面的程式碼後, 我們應該看到數字 7 的 25 個不同形態:

7

另外, 我們也可以選擇將 MNIST 圖片資料和標籤儲存為 CSV 檔案, 這樣就可以在不支援特殊的位元組格式的程式中開啟資料集. 但是, 有一點要說明, CSV 的檔案格式將會佔用更多的磁碟空間, 如下所示:

  • train_img.csv: 109.5 MB
  • train_labels.csv: 120 KB
  • test_img.csv: 18.3 MB
  • test_labels: 20 KB

如果我們打算儲存這些 CSV 檔案, 在將 MNIST 資料集載入入 NumPy array 以後, 我們應該執行下列程式碼:

np.savetxt('train_img.csv', X_train,
           fmt='%i', delimiter=',')
np.savetxt('train_labels.csv', y_train,
           fmt='%i', delimiter=',')
np.savetxt('test_img.csv', X_test,
           fmt='%i', delimiter=',')
np.savetxt('test_labels.csv', y_test,
           fmt='%i', delimiter=',')

一旦將資料集儲存為 CSV 檔案, 我們也可以用 NumPy 的 genfromtxt 函式重新將它們載入入程式中:

X_train = np.genfromtxt('train_img.csv',
                        dtype=int, delimiter=',')
y_train = np.genfromtxt('train_labels.csv',
                        dtype=int, delimiter=',')
X_test = np.genfromtxt('test_img.csv',
                       dtype=int, delimiter=',')
y_test = np.genfromtxt('test_labels.csv',
                       dtype=int, delimiter=',')

不過, 從 CSV 檔案中載入 MNIST 資料將會顯著發給更長的時間, 因此如果可能的話, 還是建議你維持資料集原有的位元組格式.

出處:https://blog.csdn.net/simple_the_best/article/details/75267863