1. 程式人生 > >機器學習(2) - KNN識別MNIST

機器學習(2) - KNN識別MNIST

min lose fse skip show turn ESS 行數 sna

代碼

https://github.com/s055523/MNISTTensorFlowSharp

數據的獲得

數據可以由http://yann.lecun.com/exdb/mnist/下載。之後,儲存在trainDir中,下次就不需要下載了。

技術分享圖片
/// <summary>
        /// 如果文件不存在就去下載
        /// </summary>
        /// <param name="urlBase">下載地址</param>
        /// <param name="trainDir">文件目錄地址</param>
/// <param name="file">文件名</param> /// <returns></returns> public static Stream MaybeDownload(string urlBase, string trainDir, string file) { if (!Directory.Exists(trainDir)) { Directory.CreateDirectory(trainDir); }
var target = Path.Combine(trainDir, file); if (!File.Exists(target)) { var wc = new WebClient(); wc.DownloadFile(urlBase + file, target); } return File.OpenRead(target); }
View Code

數據格式處理

下載下來的文件共有四個,都是擴展名為gz的壓縮包。

train-images-idx3-ubyte.gz 55000張訓練圖片和5000張驗證圖片

train-labels-idx1-ubyte.gz 訓練圖片對應的數字標簽(即答案)

t10k-images-idx3-ubyte.gz 10000張測試圖片

t10k-labels-idx1-ubyte.gz 測試圖片對應的數字標簽(即答案)

處理圖片數據壓縮包

每個壓縮包的格式為:

偏移量

類型

意義

0

Int32

2051或2049

一個定死的魔術數。用來驗證該壓縮包是訓練集(2051)或測試集(2049)

4

Int32

60000或10000

壓縮包的圖片數

8

Int32

28

每個圖片的行數

12

Int32

28

每個圖片的列數

16

Unsigned byte

0 - 255

第一張圖片的第一個像素

17

Unsigned byte

0 - 255

第一張圖片的第二個像素

因此,我們可以使用一個統一的方式將數據處理。我們只需要那些圖片像素。

技術分享圖片
/// <summary>
        /// 從數據流中讀取下一個int32
        /// </summary>
        /// <param name="s"></param>
        /// <returns></returns>
        int Read32(Stream s)
        {
            var x = new byte[4];
            s.Read(x, 0, 4);
            return DataConverter.BigEndian.GetInt32(x, 0);
        }

        /// <summary>
        /// 處理圖片數據
        /// </summary>
        /// <param name="input"></param>
        /// <param name="file"></param>
        /// <returns></returns>
        MnistImage[] ExtractImages(Stream input, string file)
        {
            //文件是gz格式的
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                //不是2051說明下載的文件不對
                if (Read32(gz) != 2051)
                {
                    throw new Exception("不是2051說明下載的文件不對: " + file);
                }
                //圖片數
                var count = Read32(gz);
                //行數
                var rows = Read32(gz);
                //列數
                var cols = Read32(gz);

                Console.WriteLine($"準備讀取{count}張圖片。");

                var result = new MnistImage[count];
                for (int i = 0; i < count; i++)
                {
                    //圖片的大小(每個像素占一個bit)
                    var size = rows * cols;
                    var data = new byte[size];

                    //從數據流中讀取這麽大的一塊內容
                    gz.Read(data, 0, size);

                    //將讀取到的內容轉換為MnistImage類型
                    result[i] = new MnistImage(cols, rows, data);
                }
                return result;
            }
        }
View Code

準備一個MnistImage類型:

技術分享圖片
/// <summary>
    /// 圖片類型
    /// </summary>
    public struct MnistImage
    {
        public int Cols, Rows;
        public byte[] Data;
        public float[] DataFloat;

        public MnistImage(int cols, int rows, byte[] data)
        {
            Cols = cols;
            Rows = rows;
            Data = data;
            DataFloat = new float[data.Length];
            for (int i = 0; i < data.Length; i++)
            {
                //數據歸一化(這裏將0-255除255變成了0-1之間的小數)
                //也可以歸一為-0.5到0.5之間
                DataFloat[i] = Data[i] / 255f;
            }
        }
    }
View Code

這樣一來,圖片數據就處理完成了。

處理數字標簽數據壓縮包

數字標簽數據壓縮包和圖片數據壓縮包的格式類似。

偏移量

類型

意義

0

Int32

2051或2049

一個定死的魔術數。用來驗證該壓縮包是訓練集(2051)或測試集(2049)

4

Int32

60000或10000

壓縮包的數字標簽數

5

Unsigned byte

0 - 9

第一張圖片對應的數字

6

Unsigned byte

0 - 9

第二張圖片對應的數字

它的處理更加簡單。

技術分享圖片
/// <summary>
        /// 處理標簽數據
        /// </summary>
        /// <param name="input"></param>
        /// <param name="file"></param>
        /// <returns></returns>
        byte[] ExtractLabels(Stream input, string file)
        {
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                //不是2049說明下載的文件不對
                if (Read32(gz) != 2049)
                {
                    throw new Exception("不是2049說明下載的文件不對:" + file);
                }
                var count = Read32(gz);
                var labels = new byte[count];

                gz.Read(labels, 0, count);

                return labels;
            }
        }
View Code

將數字標簽轉化為二維數組:one-hot編碼

由於我們的數字為0-9,所以,可以視為有十個class。此時,為了後續的處理方便,我們將數字標簽轉化為數組。因此,一組標簽就轉換為了一個二維數組。

例如,標簽0變成[1,0,0,0,0,0,0,0,0,0]