使用 Java 讀取 Mnist 資料集

0. 前言

上週一時衝動,決定自己動手寫一個 MLP(多層感知機)。最後的測試部分用它來識別手寫數字,也就是在 MNIST 資料集上訓練並測試效果。讀取 MNIST 資料集時本打算直接使用現成的輪子,卻沒有找到用 Java 寫的。於是便根據官網的儲存格式說明,自己寫了一個。


1. MNIST 資料集

      手寫數字 0 的樣例圖
  • 資料集的格式
    • IMAGE FILE (以 train-images-idx3-ubyte 為例)
      [offset] [type]          [value]          [description] 
      0000     32 bit integer  0x00000803(2051) magic number          // 魔數,就像 Java class 檔案開頭的 “CAFEBABE”,可視為一種簡單的檔案格式檢驗,本身沒有其他含義
      0004     32 bit integer  60000            number of images      // 表明一共有 60000 個樣例
      0008     32 bit integer  28               number of rows        // 影像高度,即縱向的畫素點數
      0012     32 bit integer  28               number of columns     // 影像寬度,即橫向的畫素點數
      0016     unsigned byte   ??               pixel                 // 對應畫素點的值(0 ~ 255),畫素逐行(row-major)排列;0 表示背景(白),255 表示前景(黑)
      0017     unsigned byte   ??               pixel 
      xxxx     unsigned byte   ??               pixel
    • LABEL FILE (以 train-labels-idx1-ubyte 為例)
      [offset] [type]          [value]          [description] 
      0000     32 bit integer  0x00000801(2049) magic number (MSB first)  // 同上
      0004     32 bit integer  60000            number of items           // 同上
      0008     unsigned byte   ??               label                     // 對應樣本的標籤,即對應影像中的手寫數字是幾(0 ~ 9)
      0009     unsigned byte   ??               label 
      xxxx     unsigned byte   ??               label

2. 程式碼

  • 讀取資料集程式碼
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/**
 * Reads the MNIST IDX data files (images and labels) into plain {@code double} arrays.
 *
 * <p>All 32-bit header fields are read big-endian (MSB first), which is exactly what
 * {@link DataInputStream#readInt()} does.
 */
public class MnistRead {

    public static final String TRAIN_IMAGES_FILE = "data/mnist/train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "data/mnist/train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "data/mnist/t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "data/mnist/t10k-labels.idx1-ubyte";

    /** Magic number stored at offset 0 of an IDX3 image file (2051). */
    private static final int IMAGE_MAGIC = 0x00000803;
    /** Magic number stored at offset 0 of an IDX1 label file (2049). */
    private static final int LABEL_MAGIC = 0x00000801;

    /**
     * Converts bytes into a lowercase hex string, two characters per byte.
     *
     * @param bytes the bytes to convert
     * @return the hex string, e.g. {@code "00000803"} for {@code {0, 0, 8, 3}}
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            sb.append(Character.forDigit((b >> 4) & 0xF, 16));
            sb.append(Character.forDigit(b & 0xF, 16));
        }
        return sb.toString();
    }

    /**
     * Reads every image of a 'train' or 'test' IDX3 image file as raw pixel values.
     *
     * @param fileName the 'train' or 'test' image file
     * @return one row per image; each row holds {@code rows * columns} pixel values (0..255) in file order
     * @throws IllegalArgumentException if the magic number is not 0x00000803 or a header count is negative
     * @throws UncheckedIOException if the file cannot be read or ends before all pixels are read
     */
    public static double[][] getImages(String fileName) {
        return getImages(fileName, false);
    }

    /**
     * Reads every image of a 'train' or 'test' IDX3 image file, optionally scaling pixels to [0, 1].
     *
     * @param fileName  the 'train' or 'test' image file
     * @param normalize if {@code true}, each pixel is divided by 255.0; otherwise raw values 0..255 are kept
     * @return one row per image; each row holds {@code rows * columns} pixel values in file order
     * @throws IllegalArgumentException if the magic number is not 0x00000803 or a header count is negative
     * @throws UncheckedIOException if the file cannot be read or ends before all pixels are read
     */
    public static double[][] getImages(String fileName, boolean normalize) {
        try (DataInputStream in = open(fileName)) {
            checkMagic(in.readInt(), IMAGE_MAGIC, fileName);
            int number = readCount(in, "number of images", fileName);
            int rows = readCount(in, "number of rows", fileName);
            int cols = readCount(in, "number of columns", fileName);
            int size = Math.multiplyExact(rows, cols);

            double[][] x = new double[number][size];
            byte[] pixels = new byte[size];
            for (int i = 0; i < number; i++) {
                // readFully throws EOFException on a truncated file instead of yielding -1 values.
                in.readFully(pixels);
                for (int j = 0; j < size; j++) {
                    int value = pixels[j] & 0xFF;  // pixels are stored as unsigned bytes
                    x[i][j] = normalize ? value / 255.0 : value;
                }
            }
            return x;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Reads every label of a 'train' or 'test' IDX1 label file.
     *
     * @param fileName the 'train' or 'test' label file
     * @return one entry per sample: the digit (0..9) shown in the corresponding image
     * @throws IllegalArgumentException if the magic number is not 0x00000801 or the item count is negative
     * @throws UncheckedIOException if the file cannot be read or ends before all labels are read
     */
    public static double[] getLabels(String fileName) {
        try (DataInputStream in = open(fileName)) {
            checkMagic(in.readInt(), LABEL_MAGIC, fileName);
            int number = readCount(in, "number of items", fileName);

            double[] y = new double[number];
            for (int i = 0; i < number; i++) {
                // readUnsignedByte throws EOFException at end of stream rather than returning -1.
                y[i] = in.readUnsignedByte();
            }
            return y;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Opens {@code fileName} for buffered, big-endian primitive reads. */
    private static DataInputStream open(String fileName) throws IOException {
        return new DataInputStream(new BufferedInputStream(new FileInputStream(fileName)));
    }

    /** Rejects a file whose leading magic number does not match the expected IDX type. */
    private static void checkMagic(int actual, int expected, String fileName) {
        if (actual != expected) {
            throw new IllegalArgumentException("Please select the correct file! (" + fileName
                    + ": magic number 0x" + Integer.toHexString(actual)
                    + ", expected 0x" + Integer.toHexString(expected) + ")");
        }
    }

    /** Reads one 32-bit header count and rejects negative values before they size an array. */
    private static int readCount(DataInputStream in, String what, String fileName) throws IOException {
        int value = in.readInt();
        if (value < 0) {
            throw new IllegalArgumentException("Negative " + what + " in " + fileName + ": " + value);
        }
        return value;
    }

    public static void main(String[] args) {
        double[][] trainImages = getImages(TRAIN_IMAGES_FILE);
        double[] trainLabels = getLabels(TRAIN_LABELS_FILE);

        double[][] testImages = getImages(TEST_IMAGES_FILE);
        double[] testLabels = getLabels(TEST_LABELS_FILE);

        System.out.printf("train: %d images, %d labels%n", trainImages.length, trainLabels.length);
        System.out.printf("test:  %d images, %d labels%n", testImages.length, testLabels.length);
    }
}

  • 顯示影像程式碼
 * draw a gray picture and the image format is JPEG.
 * @param pixelValues pixelValues and ordered by column.
 * @param width       width
 * @param high        high
 * @param fileName    image saved file.
public static void drawGrayPicture(int[] pixelValues, int width, int high, String fileName) throws IOException {
    BufferedImage bufferedImage = new BufferedImage(width, high, BufferedImage.TYPE_INT_RGB);
    for (int i = 0; i < width; i++) {
        for (int j = 0; j < high; j++) {
            int pixel = 255 - pixelValues[i * high + j];
            int value = pixel + (pixel << 8) + (pixel << 16);   // r = g = b 時,正好為灰度
            bufferedImage.setRGB(j, i, value);
    ImageIO.write(bufferedImage, "JPEG", new File(fileName));

3. 還有什麼


