1. 程式人生 > >關於MNIST資料集無法用scikit載入的解決辦法

關於MNIST資料集無法用scikit載入的解決辦法

最近在看Hands-on machine learning with scikit-learn and tensorflow,在第三章載入MNIST資料時,出現了問題。

書中給的程式碼是:

# Snippet quoted from the book; the author reports that this call failed to
# load the MNIST data on their machine (see the surrounding text).
from sklearn.datasets import fetch_mldata

mnist = fetch_mldata('MNIST original')

我的電腦是win7,64位,用的是python3.5。在載入時無法讀取bytes。所以在網上尋找辦法,大多都不行,包括本書github裡的解決辦法也沒能解決。最後在簡書裡找到了一份程式碼,親測可用。

下面這份檔案可以在官網上找到,其他三個檔案也是。

train-images.idx3-ubyte

以下是經過我修改可以返回資料集的完整程式碼:

import numpy as np
import struct
import matplotlib.pyplot as plt

# Training-set image file
train_images_idx3_ubyte_file = 'C:\\Users\\Administrator\\PycharmProjects\\untitled\\python檔案包\\Hands-on machine learning with scikit-learn&tensorflow\\第三章\\mnist\\train-images.idx3-ubyte'
# Training-set label file
train_labels_idx1_ubyte_file = 'C:\\Users\\Administrator\\PycharmProjects\\untitled\\python檔案包\\Hands-on machine learning with scikit-learn&tensorflow\\第三章\\mnist\\train-labels.idx1-ubyte'
# Test-set image file
test_images_idx3_ubyte_file = 'C:\\Users\\Administrator\\PycharmProjects\\untitled\\python檔案包\\Hands-on machine learning with scikit-learn&tensorflow\\第三章\\mnist\\t10k-images.idx3-ubyte'
# Test-set label file
test_labels_idx1_ubyte_file = 'C:\\Users\\Administrator\\PycharmProjects\\untitled\\python檔案包\\Hands-on machine learning with scikit-learn&tensorflow\\第三章\\mnist\\t10k-labels.idx1-ubyte'

# Number of decoded items between two progress messages.
_PROGRESS_STEP = 10000


def _read_binary(path):
    """Return the whole content of *path* as bytes, closing the file afterwards."""
    with open(path, 'rb') as stream:
        return stream.read()


def _decode_payload(bin_data, offset, num_items, item_shape):
    """Decode *num_items* records of unsigned bytes starting at *offset*.

    :param bin_data: raw bytes of the whole idx file
    :param offset: position of the first payload byte (i.e. the header size)
    :param num_items: number of records announced by the header
    :param item_shape: shape of one record, ``()`` for a scalar label
    :return: float64 np.array of shape ``(num_items,) + item_shape``
    :raises ValueError: if the file is shorter than the header announces
    """
    item_shape = tuple(item_shape)
    item_size = 1
    for dim in item_shape:
        item_size *= dim

    # Fail early with a clear message instead of part-way through decoding.
    required = offset + num_items * item_size
    if len(bin_data) < required:
        raise ValueError(
            'idx file is truncated: expected at least %d bytes, got %d'
            % (required, len(bin_data))
        )

    # np.empty defaults to float64, which is the dtype callers receive.
    decoded = np.empty((num_items,) + item_shape)
    # Decode in chunks so the progress messages are emitted while decoding,
    # one message per _PROGRESS_STEP fully decoded items.
    for start in range(0, num_items, _PROGRESS_STEP):
        stop = min(start + _PROGRESS_STEP, num_items)
        chunk = np.frombuffer(
            bin_data,
            dtype=np.uint8,
            count=(stop - start) * item_size,
            offset=offset + start * item_size,
        )
        decoded[start:stop] = chunk.reshape((stop - start,) + item_shape)
        if stop % _PROGRESS_STEP == 0:
            print('已解析 %d' % stop + '張')
    return decoded


def decode_idx3_ubyte(idx3_ubyte_file):
    """
    Generic parser for idx3 (image) files.

    :param idx3_ubyte_file: path of the idx3 file
    :return: n*row*col float64 np.array, n being the number of images
    :raises ValueError: if the pixel data is shorter than the header announces
    """
    # Read the binary data
    bin_data = _read_binary(idx3_ubyte_file)

    # Parse the header: magic number, image count, image height, image width
    # (four big-endian 32-bit integers).
    offset = 0
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔數:%d, 圖片數量: %d張, 圖片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # Parse the pixel data that follows the header
    offset += struct.calcsize(fmt_header)
    return _decode_payload(bin_data, offset, num_images, (num_rows, num_cols))


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    Generic parser for idx1 (label) files.

    :param idx1_ubyte_file: path of the idx1 file
    :return: float64 np.array of length n, n being the number of labels
    :raises ValueError: if the label data is shorter than the header announces
    """
    # Read the binary data
    bin_data = _read_binary(idx1_ubyte_file)

    # Parse the header: magic number and label count
    # (two big-endian 32-bit integers).
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔數:%d, 圖片數量: %d張' % (magic_number, num_images))

    # Parse the labels that follow the header, one unsigned byte each
    offset += struct.calcsize(fmt_header)
    return _decode_payload(bin_data, offset, num_images, ())


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """
    TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  60000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

    :param idx_ubyte_file: path of the idx file
    :return: n*row*col np.array, n being the number of images
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """
    TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.

    :param idx_ubyte_file: path of the idx file
    :return: np.array of length n, n being the number of labels
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """
    TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  10000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

    :param idx_ubyte_file: path of the idx file
    :return: n*row*col np.array, n being the number of images
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """
    TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  10000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.

    :param idx_ubyte_file: path of the idx file
    :return: np.array of length n, n being the number of labels
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def run(*, preview=False, preview_count=10):
    """
    Load the four MNIST files from their default paths and return the data.

    The label-count printout and the sample plots used to sit after the
    ``return`` statement and could never execute; they are now reachable
    through ``preview=True``. The default call behaves as before: it only
    loads and returns the arrays.

    :param preview: when true, print the label counts and display the first
        samples with their labels to check that the data was read correctly
    :param preview_count: number of training samples shown when previewing
    :return: tuple ``(X_train, X_test, y_train, y_test)``
    """
    train_images = load_train_images()
    train_labels = load_train_labels()
    test_images = load_test_images()
    test_labels = load_test_labels()

    X_train = train_images
    X_test = test_images
    y_train = train_labels
    y_test = test_labels

    if preview:
        print('train set label number is '+str(len(y_train)))
        print('test set label number is '+str(len(y_test)))
        # Show the first samples and their labels to verify the decoding.
        # `plt` is the module-level matplotlib.pyplot import of this file.
        for i in range(min(preview_count, len(train_labels))):
            print(train_labels[i])
            plt.imshow(train_images[i], cmap='gray')
            plt.show()
        print('done')

    return X_train, X_test, y_train, y_test


if __name__ == '__main__':
    X_train, X_test, y_train, y_test = run()