JAVAEE與人工智慧實戰之--通過MNIST進行模型訓練
MNIST簡介
一個手寫數字識別庫,世界上最權威的,美國郵政系統開發的,手寫內容是0-9的內容,手寫內容採集於美國人口調查局的員工和高中生。包括6萬張訓練圖片和1萬張測試圖片構成的,每張圖片都是28*28大小,而且都是黑白色構成。
MINIST實驗包含了四個檔案,其中train-images-idx3-ubyte是60000個圖片樣本,train-labels-idx1-ubyte是這60000個圖片對應的數字標籤,t10k-images-idx3-ubyte是用於測試的樣本,t10k-labels-idx1-ubyte是測試樣本對應的數字標籤。
我們以測試集中的一個圖片為例來說明圖片的儲存形式:
MNIST圖片並不是傳統意義上的png或者jpg格式的圖片,因為png或者jpg的圖片格式,會帶有很多幹擾資訊(如:資料塊,圖片頭,圖片尾,長度等等),這些圖片會被處理成很簡易的陣列,圖片長度為28,寬度也為28,總畫素為28 28=784,在MNIST儲存的就是一個長度為784的陣列,陣列中的每個值表示每個點的RGB值,其中黑色用0表示、白色用255表示。我們可以將陣列轉成28 28的二維陣列,如下圖所示,可以看出這是一個表示的是數字5的圖片。

image.png
如果把畫素寫成圖片,圖片是這樣的:

image.png
通過MNIST訓練模型
在BP神經網路中, 層數、節點個數、學習速率、訓練集、訓練次數,都會影響到最終模型的泛化能力。因此,在設計模型時,節點的個數,學習速率的大小,以及訓練次數都是需要考慮的。
本例項中設定神經網路層數為3層,其中輸入特徵為784個,每層節點數分別為300、100、10個,學習速率設定為0.5,迭代週期為30,批量設定60個。通過訓練該模型在MNIST測試集上的平均準確率為96.68 %左右。
public static void main(String[] args) { //三層網路,各層節點數為784*300*10 輸入特徵 784個隱藏層節點300個 輸出層節點10個 int[] nodeNum = {784, 300,100, 10}; //週期被定義為向前和向後傳播中所有批次的單次訓練迭代。 int epoch = 30; //每次批量的樣本數 int batchSize = 60; double learningRate=0.5; NetTrainAndTest.train(nodeNum, epoch, batchSize,learningRate); }
對模型進行序列化
為了“一次訓練、多次使用”,我們對訓練好的模型進行序列化儲存,後續即可通過反序列化的方式讀取恢復模型。
/** * 通過序列化方式儲存模型 * * @param fileName 模型存放的檔名 */ public static <T> void saveModel(String fileName, T obj) { try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(fileName)); ObjectOutputStream oos = new ObjectOutputStream(bos)) { oos.writeObject(obj); } catch (IOException e) { throw new RuntimeException(e); } } /** * 恢復模型 * * @param fileName 模型持久化的存放位置 檔名 *<p> *//@SuppressWarnings("unchecked") */ public static <T> T restoreModel(String fileName) { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(fileName)); ObjectInputStream ois = new ObjectInputStream(bis)) { return (T) ois.readObject(); } catch (IOException | ClassNotFoundException e) { throw new RuntimeException(e); } }
上一篇 | JAVAEE與人工智慧目錄 | [下一篇] |
---|