1. 程式人生 > >使用 Java 讀取 MNIST 資料集

使用 Java 讀取 MNIST 資料集

使用 Java 讀取 Mnist 資料集

0. 前言

好久沒寫 blog 了,沒有堅持住,心中滿滿的負罪感!!!

上週一時衝動了,決定自己 code 一下 mlp (多層感知機)。最後的測試部分使用它來識別手寫數字,也就是在 MNIST 資料集上訓練並測試效果。在讀取 MNIST 資料集時本打算使用輪子,可並沒找到使用 Java 創造的輪子。於是,根據官網的儲存格式說明自己寫了一個。

遂得此文,望可拋磚引玉~~(廢話少說!)

1. MNIST 資料集

  • THE MNIST DATABASE of handwritten digits
    • Machine Learning
      中非常出名的資料集,它以二進位制的形式儲存了每個手寫數字的畫素及標籤。下面是視覺化後的一個樣例圖。
      手寫數字 0 的樣例圖
    • 其他資訊詳見官網(點選上面的小標題可以直接進入)
  • 資料集的格式
    • IMAGE FILE (以 train-images-idx3-ubyte 為例)
      [offset] [type]          [value]          [description] 
      0000     32 bit integer  0x00000803(2051) magic number          // 魔數,就像 java 類檔案中的 “CAFEBABE”。可視為一種驗證,其實沒有~~
      0004     32 bit integer  60000            number of images      // 表明一共有 60000 個樣例
      0008     32 bit integer  28               number of rows        // 影象的行數(即每一列含有的畫素點數)
      0012     32 bit integer  28               number of columns     // 影象的列數(即每一行含有的畫素點數)
      0016     unsigned byte   ??               pixel                 // 對應畫素點的值(0 ~ 255)
      0017     unsigned byte   ??               pixel 
      ........ 
      xxxx     unsigned byte   ??               pixel
      
    • LABEL FILE (以 train-labels-idx1-ubyte 為例)
      [offset] [type]          [value]          [description] 
      0000     32 bit integer  0x00000801(2049) magic number (MSB first)  // 同上
      0004     32 bit integer  60000            number of items           // 同上
      0008     unsigned byte   ??               label                     // 對應樣本的標籤,即對應影象中的手寫數字是幾(0 ~ 9)
      0009     unsigned byte   ??               label 
      ........ 
      xxxx     unsigned byte   ??               label
      

2. 程式碼

  • 讀取資料集程式碼
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;

/**
 * Reads MNIST image and label files stored in the IDX binary format.
 *
 * <p>Every header field is a 32-bit big-endian integer. An image file starts with the magic
 * number {@code 0x00000803} followed by the image count, the row count and the column count,
 * then one unsigned byte per pixel. A label file starts with the magic number
 * {@code 0x00000801} followed by the item count, then one unsigned byte per label.
 */
public class MnistRead {

    public static final String TRAIN_IMAGES_FILE = "data/mnist/train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "data/mnist/train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "data/mnist/t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "data/mnist/t10k-labels.idx1-ubyte";

    /** Magic number expected at the start of an image file. */
    private static final int IMAGE_MAGIC = 0x00000803;

    /** Magic number expected at the start of a label file. */
    private static final int LABEL_MAGIC = 0x00000801;

    /**
     * Converts bytes into a lower-case hex string, two characters per byte.
     *
     * @param bytes the bytes to convert
     * @return the hex representation, e.g. {@code "00000803"} for {@code {0, 0, 8, 3}}
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            // Mask to 0..255 so negative bytes do not sign-extend to "ffffffxx".
            String hex = Integer.toHexString(b & 0xFF);
            if (hex.length() < 2) {
                sb.append('0');
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /**
     * Reads every image of a 'train' or 'test' image file.
     *
     * @param fileName path of the image file
     * @return one row per image; each row holds rows * columns pixel values in the range
     *         0..255, in the order they are stored in the file
     * @throws RuntimeException if the magic number does not match, or wrapping the
     *         {@link IOException} raised when the file cannot be read, is truncated or
     *         has an invalid header
     */
    public static double[][] getImages(String fileName) {
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            if (readInt(bin) != IMAGE_MAGIC) {
                throw new RuntimeException("Please select the correct file!");
            }
            int number = readCount(bin, "number of images");
            int rows = readCount(bin, "number of rows");
            int columns = readCount(bin, "number of columns");

            // Compute in long so an absurd header cannot silently overflow the array size.
            long pixelCount = (long) rows * columns;
            if (pixelCount > Integer.MAX_VALUE) {
                throw new IOException("Image size too large: " + rows + " x " + columns);
            }
            int pixelsPerImage = (int) pixelCount;

            double[][] x = new double[number][pixelsPerImage];
            byte[] buffer = new byte[pixelsPerImage];
            for (int i = 0; i < number; i++) {
                readFully(bin, buffer);
                double[] element = x[i];
                for (int j = 0; j < pixelsPerImage; j++) {
                    // Pixels are unsigned bytes; divide by 255.0 here to normalize to [0, 1].
                    element[j] = buffer[j] & 0xFF;
                }
            }
            return x;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads every label of a 'train' or 'test' label file.
     *
     * @param fileName path of the label file
     * @return one label value per sample, in file order
     * @throws RuntimeException if the magic number does not match, or wrapping the
     *         {@link IOException} raised when the file cannot be read, is truncated or
     *         has an invalid header
     */
    public static double[] getLabels(String fileName) {
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            if (readInt(bin) != LABEL_MAGIC) {
                throw new RuntimeException("Please select the correct file!");
            }
            int number = readCount(bin, "number of items");

            byte[] buffer = new byte[number];
            readFully(bin, buffer);
            double[] y = new double[number];
            for (int i = 0; i < number; i++) {
                y[i] = buffer[i] & 0xFF;
            }
            return y;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Reads one 32-bit big-endian integer, failing if the stream ends first. */
    private static int readInt(BufferedInputStream in) throws IOException {
        int value = 0;
        for (int i = 0; i < 4; i++) {
            int b = in.read();
            if (b < 0) {
                throw new IOException("Unexpected end of file while reading the header");
            }
            value = (value << 8) | b;
        }
        return value;
    }

    /** Reads a header field that is used as a count and therefore must not be negative. */
    private static int readCount(BufferedInputStream in, String what) throws IOException {
        int value = readInt(in);
        if (value < 0) {
            throw new IOException("Invalid " + what + " in header: " + value);
        }
        return value;
    }

    /** Fills the whole buffer, looping over short reads and failing if the stream ends first. */
    private static void readFully(BufferedInputStream in, byte[] buffer) throws IOException {
        int offset = 0;
        while (offset < buffer.length) {
            int read = in.read(buffer, offset, buffer.length - offset);
            if (read < 0) {
                throw new IOException("Unexpected end of file: expected " + buffer.length
                        + " more bytes but got only " + offset);
            }
            offset += read;
        }
    }

    public static void main(String[] args) {
        double[][] trainImages = getImages(TRAIN_IMAGES_FILE);
        double[] trainLabels = getLabels(TRAIN_LABELS_FILE);

        double[][] testImages = getImages(TEST_IMAGES_FILE);
        double[] testLabels = getLabels(TEST_LABELS_FILE);

        System.out.println("train: " + trainImages.length + " images, " + trainLabels.length + " labels");
        System.out.println("test: " + testImages.length + " images, " + testLabels.length + " labels");
    }
}
  • 顯示影象程式碼
/**
 * Draws a gray picture from raw pixel values and saves it as a JPEG file.
 *
 * <p>Each value is inverted ({@code 255 - value}) before drawing, so a value of 0 is
 * rendered white and 255 is rendered black.
 *
 * @param pixelValues pixel values in row-major order: the pixel at column {@code x} of row
 *                    {@code y} is read from index {@code y * width + x}; values are
 *                    presumably in the range 0..255
 * @param width       image width in pixels (number of columns)
 * @param high        image height in pixels (number of rows)
 * @param fileName    path of the file the image is saved to
 * @throws IOException if the image cannot be written or no JPEG writer is available
 */
public static void drawGrayPicture(int[] pixelValues, int width, int high, String fileName) throws IOException {
    BufferedImage bufferedImage = new BufferedImage(width, high, BufferedImage.TYPE_INT_RGB);
    // Iterate rows over the height and columns over the width so that setRGB(x, y) stays
    // inside the image even when it is not square.
    for (int row = 0; row < high; row++) {
        for (int col = 0; col < width; col++) {
            int pixel = 255 - pixelValues[row * width + col];
            int value = pixel + (pixel << 8) + (pixel << 16);   // r = g = b yields a gray tone
            bufferedImage.setRGB(col, row, value);
        }
    }
    // ImageIO.write reports a missing writer by returning false rather than throwing.
    if (!ImageIO.write(bufferedImage, "JPEG", new File(fileName))) {
        throw new IOException("No JPEG writer available to save " + fileName);
    }
}

3. 還有什麼

上面的讀取過程還是很簡單的。

想分享一下自己 code 的 bp(反向傳播)。希望有時間~~