1. 程式人生 > >python讀取cifar10資料集

python讀取cifar10資料集

最近學習卷積網路用到cifar10資料集,自己寫了一個工具類,用來讀取已經下載到本地的cifar10資料集。

程式碼寫的不算好,但是自己用起來還可以。所以放到網上,有需要的可以拿去用。程式碼比較少,所以沒有寫註釋。下面介紹一下實現的功能。完整的程式碼可以在github上下載。地址:https://github.com/NewQJX/DeepLearning/tree/master/Cifar10

檔名為:input_data.py

建立了一個類Cifar10():用於讀取本地資料集,對資料集進行操作

__init__(self, path, one_hot = True): 引數path為本地資料集儲存路徑。one_hot:決定是否對類別獨熱編碼

_load_data():用於載入資料集

next_batch(batch_size, shuffle = True): 該方法返回指定batch_size大小的訓練集, shuffle:決定是否打亂順序

下面是使用該類的方法:

import input_data
import numpy as np

path = r"E:\pythonCode\TensorFlow\cifar10\cifar-10-batches-py"
cifar10 = input_data.load_cifar10(path, one_hot = True)
images = cifar10.images
print("訓練集圖片:" + str(images.shape))
labels = cifar10.labels
print("訓練集類別:" + str(labels.shape))
test_images = cifar10.test.images
print("測試集圖片:"+ str(test_images.shape))
test_labels = cifar10.test.labels
print("測試集類別:"+ str(test_labels.shape))
batch_xs, batch_ys = cifar10.next_batch(batch_size = 500, shuffle = True)
print("batch_xs shape is:" + str(batch_xs.shape))
print("batch_ys shape is:" + str(batch_ys.shape))