1. 程式人生 > >cifar10資料格式以及讀取方式

cifar10資料格式以及讀取方式

cifar10 資料網站
http://www.cs.toronto.edu/~kriz/cifar.html

讀取下面的檔案

CIFAR-10 binary version (suitable for C programs)	162 MB	c32a1d4ab5d03f1284b67883e8d87530

下載cifar-10-binary.tar.gz 到./data/資料夾下

cd ./data/

解壓下載後的檔案到./data/

tar -xzvf cifar-10-binary.tar.gz

會出現一個資料夾 ‘cifar-10-batches-bin’
以及資料夾下的這些檔案

batches.meta.txt
data_batch_1.bin    
data_batch_2.bin    
data_batch_3.bin    
data_batch_4.bin    
data_batch_5.bin    
readme.html         
test_batch.bin

data_batch_1.bin 到 data_batch_5.bin 是訓練集二進位制檔案
每個檔案中有10000張圖片和10000個標記
共有50000張圖片和50000個標記
每個二進位制檔案由連續的記錄組成,每筆記錄的第一個位元組是標記,後面的32x32x3(共3072)個位元組是圖片,
圖片中前32x32 是 red channel, 接著32x32是 green channel,然後32x32 是blue channel. 之後的每筆記錄依次類推.
test_batch.bin 是測試集檔案

每個二進位制檔案是30730000個位元組 (10000 筆記錄 x (1 + 32x32x3) = 10000 x 3073)

下面是讀取資料集的類檔案
cifar10_dataset.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : cifar10_dataset.py
# Create date : 2018-12-24 19:58
# Modified date : 2018-12-31 16:21
# Author : DARREN
# Describe : not set
# Email : [email protected]
##################################### from __future__ import division from __future__ import print_function #http://www.cs.toronto.edu/~kriz/cifar.html import sys import os import struct import numpy as np import matplotlib.pyplot as plt # pylint: disable=bad-continuation meta_lt = [ "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", ] # pylint: enable=bad-continuation def create_path(path): if not os.path.isdir(path): os.makedirs(path) def open_file_with_full_name(full_path, open_type): try: file_object = open(full_path, open_type) return file_object except Exception as e: print(e) return None def get_file_full_name(path, name): if path[-1] == "/": full_name = path + name else: full_name = path + "/" + name return full_name def open_file(path, name, open_type='a'): file_name = get_file_full_name(path, name) return open_file_with_full_name(file_name, open_type) def _get_file_header_data(file_obj, header_len, unpack_str): raw_header = file_obj.read(header_len) header_data = struct.unpack(unpack_str, raw_header) return header_data def _read_a_image(file_object): raw_img = file_object.read(32 * 32) red_img = struct.unpack(">1024B", raw_img) raw_img = file_object.read(32 * 32) green_img = struct.unpack(">1024B", raw_img) raw_img = file_object.read(32 * 32) blue_img = struct.unpack(">1024B", raw_img) img = np.zeros(shape=(1024, 3)) for i in range(1024): l = [red_img[i], green_img[i], blue_img[i]] img[i] = l img = img.reshape(32, 32, 3) img = img / 255. 
return img def _read_one_image(file_object): raw_img = file_object.read(32 * 32 * 3) img = struct.unpack(">3072B", raw_img) return img def _read_a_label(file_object): raw_label = file_object.read(1) label = struct.unpack(">B", raw_label) return label def _get_image_full_name(path, label, count): meta = meta_lt[label[0]] full_path = "%s%s" %(path, meta) create_path(full_path) full_path_name = "%s/%s.jpg" %(full_path, count) return full_path_name def save_image(image, full_path_name): plt.imshow(image) plt.savefig(full_path_name) plt.close() class Cifar10Set(object): def __init__(self, file_path): super(Cifar10Set, self).__init__() # pylint: disable=bad-continuation self._train_file_list = [ "data_batch_1.bin", "data_batch_2.bin", "data_batch_3.bin", "data_batch_4.bin", "data_batch_5.bin" ] # pylint: enable=bad-continuation self._test_file_list = ["test_batch.bin",] self.file_path = file_path def _read_file(self, file_name): file_object = open_file(self.file_path, file_name, open_type="rb") return file_object def _generate_a_batch(self, batch_size, file_list): images = np.zeros(shape=(batch_size, 32 * 32 * 3)) labels = np.zeros(shape=(batch_size, 10)) i = 0 file_name = file_list[i] file_name = "cifar-10-batches-bin/%s" % file_name train_file = self._read_file(file_name) count = 0 ret = True while True: while count < batch_size: try: label = _read_a_label(train_file) image = _read_one_image(train_file) images[count] = image labels[count][label[0]] = 1 count += 1 except Exception as err: #print(err) if i >= len(self._train_file_list): ret = False break else: i += 1 if i < len(file_list): file_name = file_list[i] file_name = "cifar-10-batches-bin/%s" % file_name train_file = self._read_file(file_name) count = 0 yield images, labels.astype(int), ret images = np.zeros(shape=(batch_size, 32*32*3)) labels = np.zeros(shape=(batch_size, 10)) def generator_images(self, file_list, path): count = 1 for i in range(len(file_list)): file_name = file_list[i] file_name = 
"cifar-10-batches-bin/%s" % file_name train_file = self._read_file(file_name) while True: try: label = _read_a_label(train_file) image = _read_a_image(train_file) full_path_name = _get_image_full_name(path, label, count) save_image(image, full_path_name) print("file:%s count:%s"% (file_name, count)) except Exception as err: print(err) break count += 1 def generator_train_images(self, path): self.generator_images(self._train_file_list, path) def generator_test_images(self, path): self.generator_images(self._test_file_list, path) def get_train_data_generator(self, batch_size=128): file_list = self._train_file_list gennerator = self._generate_a_batch(batch_size, file_list) return gennerator def get_test_data_generator(self, batch_size=128): file_list = self._test_file_list gennerator = self._generate_a_batch(batch_size, file_list) return gennerator def get_a_batch_data(self, data_generator): if sys.version > '3': batch_img, batch_labels, status = data_generator.__next__() else: batch_img, batch_labels, status = data_generator.next() return batch_img, batch_labels, status

下面是main.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2018-12-23 16:53
# Modified date : 2018-12-31 15:37
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function

import cifar10_dataset

def test_cifar10_train_set():
    """Walk the training set in batches of 100, printing each label batch.

    Stops at the first batch whose status is False; that batch is not printed.
    """
    dataset = cifar10_dataset.Cifar10Set("./data/")
    batches = dataset.get_train_data_generator(100)
    count = 1
    status = True
    while status:
        _, batch_labels, status = dataset.get_a_batch_data(batches)
        print("count:%s status:%s " % (count, status))
        if status:
            count += 1
            print(str(batch_labels))

def test_cifar10_test_set():
    """Walk the test set in batches of 100, printing each label batch.

    Stops at the first batch whose status is False; that batch is not printed.
    """
    reader = cifar10_dataset.Cifar10Set("./data/")
    test_batches = reader.get_test_data_generator(100)
    seen = 1
    while True:
        batch = reader.get_a_batch_data(test_batches)
        labels, more = batch[1], batch[2]
        print("count:%s status:%s " % (seen, more))
        if not more:
            return
        seen += 1
        print(str(labels))

def test_generator_images():
    """Export the training images first, then the test images."""
    for export in (test_generator_train_images, test_generator_test_images):
        export()

def test_generator_train_images():
    """Save the training images read from ./data/ as pictures under ./img/train/."""
    cifar10_dataset.Cifar10Set("./data/").generator_train_images("./img/train/")

def test_generator_test_images():
    """Save the test images read from ./data/ as pictures under ./img/test/."""
    source_dir, target_dir = "./data/", "./img/test/"
    reader = cifar10_dataset.Cifar10Set(source_dir)
    reader.generator_test_images(target_dir)

def run():
    """Run all demos: read the train and test batches, then export the images."""
    # Indented with spaces only: the original mixed a tab with spaces here,
    # which Python 3 rejects with a TabError.
    test_cifar10_train_set()
    test_cifar10_test_set()
    test_generator_images()


# Only run the demos when executed as a script, not when imported.
if __name__ == "__main__":
    run()

上面的程式碼可以在python2 以及python3 執行.可以批量讀取cifar10的所有訓練以及測試資料.
而且通過matplotlib 把二進位制的資料轉換儲存成了圖片.使我們可以看到這些圖片真實的樣子.
下面是一些儲存的圖片.
第一張圖片是青蛙 ,放成這麼大的話,我作為人表示這還是比較難看出來的.但是離遠點看,或者縮小了後,就有點像青蛙了.
在這裡插入圖片描述

也有些比較容易辨識的圖片 比如第二張卡車圖片, 這個還比較容易看出來.
在這裡插入圖片描述

再放上來一些
在這裡插入圖片描述

在這裡插入圖片描述

在這裡插入圖片描述

在這裡插入圖片描述

在這裡插入圖片描述

在這裡插入圖片描述

在這裡插入圖片描述

在這裡插入圖片描述