1. 程式人生 > >numpy方法讀取載入mnist資料集

numpy方法讀取載入mnist資料集

方法來自機器之心公眾號

首先下載mnist資料集,並將裡面四個資料夾解壓出來,下載方法見前面的部落格

import tensorflow as tf
import numpy as np
import os

dataset_path = r'D:\PycharmProjects\tensorflow\MNIST_data' # 這是我存放mnist資料集的位置
is_training = True


# 定義載入mnist的函式
def load_mnist(path, is_training):

    # trX將載入儲存所有60000張灰度圖
    fd = open(os.path.join(path, 'train-images.idx3-ubyte'))
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trX = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float)

    fd = open(os.path.join(path, 'train-labels.idx1-ubyte'))
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trY = loaded[8:].reshape((60000)).astype(np.float)

    #teX將儲存所有一萬張測試用的圖片
    fd = open(os.path.join(path, 't10k-images.idx3-ubyte'))
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float)

    fd = open(os.path.join(path, 't10k-labels.idx1-ubyte'))
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teY = loaded[8:].reshape((10000)).astype(np.float)

    # 將所有訓練圖片表示為一個4維張量 [60000, 28, 28, 1],其中每個畫素值縮放到0和1之間
    trX = tf.convert_to_tensor(trX / 255., tf.float32)

    # one hot編碼為 [num_samples, 10]
    trY = tf.one_hot(trY, depth=10, axis=1, dtype=tf.float32)
    teY = tf.one_hot(teY, depth=10, axis=1, dtype=tf.float32)

    # 訓練和測試時返回不同的資料
    if is_training:
        return trX, trY
    else:
        return teX / 255., teY


def get_batch_data():
    trX, trY = load_mnist(dataset_path, True)

    # 每次產生一個切片,每次從一個tensor列表中按順序或者隨機抽取出一個tensor放入檔名佇列
    data_queues = tf.train.slice_input_producer([trX, trY])

    # 對佇列中的樣本進行亂序處理
    X, Y = tf.train.shuffle_batch(data_queues,
                                  batch_size=batch_size,
                                  capacity=batch_size * 64,
                                  min_after_dequeue=batch_size * 32,
                                  allow_smaller_final_batch=False)
    return (X, Y)

這裡為什麼要去掉訓練集的前16個數字和標籤的前8個數字呢?我看了一下,訓練集train-images.idx3-ubyte檔案確實有47040016個數字,比28*28*60000=47040000多了16個數字,訓練集標籤train-labels.idx1-ubyte檔案下有60008個數字,也多出來8個數字,下面是mnist訓練集的樣本和標籤的資料結構:

 

可以看出在train-images.idx3-ubyte中,第一個數為32位的整數(魔數,圖片型別的數),第二個數為32位的整數(圖片的個數),第三和第四個也是32為的整數(分別代表圖片的行數和列數),接下來的都是一個位元組的無符號數(即畫素,值域為0~255)