Reading the MNIST Dataset with Java
阿新 • Published: 2018-12-13
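0. Preface
It has been a long time since my last blog post. I did not keep it up, and I feel thoroughly guilty about it!
Last week, on an impulse, I decided to code an MLP (multilayer perceptron) myself. For the final test I used it to recognize handwritten digits, that is, I trained and tested it on the MNIST dataset. I had planned to use an existing library to read the MNIST files, but I could not find one written in Java, so I wrote my own based on the storage format described on the official site.
Hence this post; I hope it prompts better ideas from others. (Enough chatter!)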
1. The MNIST dataset
- THE MNIST DATABASE of handwritten digits
- See the official site for further details (the heading above links to it)
- Format of the dataset
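- IMAGE FILE (using train-images-idx3-ubyte as an example)
  [offset] [type]          [value]           [description]
  0000     32 bit integer  0x00000803(2051)  magic number       // identifies the file type, like "CAFEBABE" in a Java class file; it can serve as a sanity check
  0004     32 bit integer  60000             number of images   // 60000 samples in total
  0008     32 bit integer  28                number of rows     // rows per image (image height in pixels)
  0012     32 bit integer  28                number of columns  // columns per image (image width in pixels)
  0016     unsigned byte   ??                pixel              // value of the corresponding pixel (0 ~ 255)
  0017     unsigned byte   ??                pixel
  ........
  xxxx     unsigned byte   ??                pixel
  The 32 bit integers are stored MSB first (big-endian), and the pixels are stored row by row.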
- LABEL FILE (using train-labels-idx1-ubyte as an example)
  [offset] [type]          [value]           [description]
  0000     32 bit integer  0x00000801(2049)  magic number (MSB first)  // identifies the file type, as in the image file
  0004     32 bit integer  60000             number of items           // 60000 labels in total, one per image
  0008     unsigned byte   ??                label                     // label of the corresponding sample, i.e. which digit the image shows (0 ~ 9)
  0009     unsigned byte   ??                label
  ........
  xxxx     unsigned byte   ??                label
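The header layout above can also be parsed without going through hex strings. The following is only a minimal sketch, assuming the header is read with `DataInputStream.readInt()`, which reads a 32 bit integer in big-endian order and therefore matches the "MSB first" layout. The class name `IdxHeader` and its fields are hypothetical and do not come from the code in this post.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical helper for illustration; not part of the original MnistRead class.
final class IdxHeader {
    static final int IMAGE_MAGIC = 0x00000803; // 2051
    static final int LABEL_MAGIC = 0x00000801; // 2049

    final int magic;
    final int count;   // number of images or labels
    final int rows;    // 0 for label files
    final int columns; // 0 for label files

    private IdxHeader(int magic, int count, int rows, int columns) {
        this.magic = magic;
        this.count = count;
        this.rows = rows;
        this.columns = columns;
    }

    /** Reads the header and leaves the stream positioned at the first pixel or label byte. */
    static IdxHeader read(InputStream in) throws IOException {
        DataInputStream din = new DataInputStream(in);
        int magic = din.readInt(); // big-endian, i.e. MSB first
        if (magic != IMAGE_MAGIC && magic != LABEL_MAGIC) {
            throw new IOException("Unexpected magic number: 0x" + Integer.toHexString(magic));
        }
        int count = din.readInt();
        if (magic == LABEL_MAGIC) {
            return new IdxHeader(magic, count, 0, 0); // label files have no row/column fields
        }
        int rows = din.readInt();
        int columns = din.readInt();
        return new IdxHeader(magic, count, rows, columns);
    }

    public static void main(String[] args) throws IOException {
        // A hand-built 16 byte image header: magic 2051, 60000 images, 28 rows, 28 columns.
        byte[] header = {
            0x00, 0x00, 0x08, 0x03,
            0x00, 0x00, (byte) 0xEA, 0x60,
            0x00, 0x00, 0x00, 0x1C,
            0x00, 0x00, 0x00, 0x1C
        };
        IdxHeader h = IdxHeader.read(new ByteArrayInputStream(header));
        System.out.println(h.count + " images of " + h.rows + " x " + h.columns);
    }
}
```

After the header, each pixel or label is one unsigned byte, so `InputStream.read()`, which returns a value from 0 to 255, can be used directly without any sign handling.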
2. Code
- Code for reading the dataset
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
public class MnistRead {
    public static final String TRAIN_IMAGES_FILE = "data/mnist/train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "data/mnist/train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "data/mnist/t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "data/mnist/t10k-labels.idx1-ubyte";
    /**
     * Convert bytes into a hex string, two lowercase hex digits per byte.
     *
     * @param bytes the bytes to convert
     * @return the hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < bytes.length; i++) {
            // Mask with 0xFF so that a negative byte is treated as unsigned.
            String hex = Integer.toHexString(bytes[i] & 0xFF);
            if (hex.length() < 2) {
                sb.append(0); // pad to two digits
            }
            sb.append(hex);
        }
        return sb.toString();
    }
    /**
     * Get the images of the training or test set.
     *
     * @param fileName the image file of the training or test set
     * @return a matrix in which each row holds the pixels of one picture
     */
    public static double[][] getImages(String fileName) {
        double[][] x = null;
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) { // check the magic number
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16); // total number of samples
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16); // number of rows per image
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16); // number of columns per image
                x = new double[number][xPixel * yPixel];
                for (int i = 0; i < number; i++) {
                    double[] element = new double[xPixel * yPixel];
                    for (int j = 0; j < xPixel * yPixel; j++) {
                        element[j] = bin.read(); // read the pixel values one by one (0 ~ 255)
                        // normalization: to scale to [0, 1], use this line instead of the one above
                        // element[j] = bin.read() / 255.0;
                    }
                    x[i] = element;
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return x;
    }
    /**
     * Get the labels of the training or test set.
     *
     * @param fileName the label file of the training or test set
     * @return the label (0 ~ 9) of each sample, in file order
     */
    public static double[] getLabels(String fileName) {
        double[] y = null;
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) { // check the magic number
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16); // total number of labels
                y = new double[number];
                for (int i = 0; i < number; i++) {
                    y[i] = bin.read(); // read the labels one by one
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return y;
    }
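    public static void main(String[] args) {
        // Distinct names are required: declaring images and labels twice does not compile.
        double[][] trainImages = getImages(TRAIN_IMAGES_FILE);
        double[] trainLabels = getLabels(TRAIN_LABELS_FILE);
        double[][] testImages = getImages(TEST_IMAGES_FILE);
        double[] testLabels = getLabels(TEST_LABELS_FILE);
        // Print the sizes as a quick sanity check.
        System.out.println("train: " + trainImages.length + " images, " + trainLabels.length + " labels");
        System.out.println("test: " + testImages.length + " images, " + testLabels.length + " labels");
    }
}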
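The reader above keeps the raw pixel values (0 ~ 255) and only hints at normalization in a commented-out line. If the raw values have already been loaded, they can be scaled afterwards. The sketch below is one possible way to do that; the class `MnistNormalize` and the method `scale` are hypothetical names introduced for illustration, and dividing by 255.0 is simply the scaling already suggested in the comment.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of the original MnistRead class.
final class MnistNormalize {

    /** Returns a scaled copy in which every pixel lies in [0, 1]; the input is left untouched. */
    static double[][] scale(double[][] images) {
        double[][] scaled = new double[images.length][];
        for (int i = 0; i < images.length; i++) {
            scaled[i] = new double[images[i].length];
            for (int j = 0; j < images[i].length; j++) {
                scaled[i][j] = images[i][j] / 255.0; // 0 stays 0, 255 becomes 1
            }
        }
        return scaled;
    }

    public static void main(String[] args) {
        // One tiny "image" with three pixels, standing in for a row returned by getImages.
        double[][] raw = { {0, 51, 255} };
        System.out.println(Arrays.toString(scale(raw)[0]));
    }
}
```

Returning a copy rather than scaling in place keeps the raw values available, for example for drawing the images later.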
- Code for displaying an image
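// Imports needed at the top of the file that contains this method:
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import javax.imageio.ImageIO;

/**
 * Draw a grayscale picture and save it in JPEG format.
 *
 * @param pixelValues pixel values (0 ~ 255), stored row by row as in the MNIST files
 * @param width image width in pixels
 * @param high image height in pixels
 * @param fileName file the image is saved to
 */
public static void drawGrayPicture(int[] pixelValues, int width, int high, String fileName) throws IOException {
    BufferedImage bufferedImage = new BufferedImage(width, high, BufferedImage.TYPE_INT_RGB);
    for (int i = 0; i < high; i++) { // i is the row (y coordinate)
        for (int j = 0; j < width; j++) { // j is the column (x coordinate)
            int pixel = 255 - pixelValues[i * width + j]; // invert: dark digit on a white background
            int value = pixel + (pixel << 8) + (pixel << 16); // r = g = b gives exactly a gray level
            bufferedImage.setRGB(j, i, value);
        }
    }
    ImageIO.write(bufferedImage, "JPEG", new File(fileName));
}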
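Note that getImages returns each picture as a double[] row, while drawGrayPicture expects an int[]. The sketch below shows one way to bridge the two and to check the pixel layout in memory before writing any file. The class `MnistRender` and its methods are hypothetical names for illustration, and the sketch assumes raw (not normalized) pixel values stored row by row.

```java
import java.awt.image.BufferedImage;

// Hypothetical helper for illustration; not part of the original post's code.
final class MnistRender {

    /** Converts one row of raw pixel values (0 ~ 255, stored as double) to int pixels. */
    static int[] toIntPixels(double[] row) {
        int[] pixels = new int[row.length];
        for (int k = 0; k < row.length; k++) {
            pixels[k] = (int) Math.round(row[k]);
        }
        return pixels;
    }

    /** Renders row-major pixels as a dark digit on a white background, without touching the disk. */
    static BufferedImage toImage(int[] pixels, int width, int height) {
        if (pixels.length != width * height) {
            throw new IllegalArgumentException("Expected " + (width * height) + " pixels, got " + pixels.length);
        }
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int gray = 255 - pixels[y * width + x]; // invert the MNIST value
                image.setRGB(x, y, (gray << 16) | (gray << 8) | gray); // r = g = b
            }
        }
        return image;
    }

    public static void main(String[] args) {
        // A 3 x 2 toy picture standing in for one row returned by getImages.
        double[] row = {0, 255, 0, 0, 0, 255};
        BufferedImage image = toImage(toIntPixels(row), 3, 2);
        System.out.println(image.getWidth() + " x " + image.getHeight());
    }
}
```

With real data, the same conversion would be applied to a single row of the matrix returned by getImages, using 28 for both the width and the height.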
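3. What else
The reading process above is quite simple.
I would also like to share my own implementation of BP (backpropagation). Hopefully I will find the time.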