1. 程式人生 > >利用softmax函式對mnist資料集簡單分類

利用softmax函式對mnist資料集簡單分類

mnist資料集的特點

  1. 每一張圖片包含28**28個畫素,我們把這一個陣列展開成一個向量,長度是28*28=784。因此在 MNIST訓練資料集中mnist.train.images 是一個形狀為 [60000, 784] 的張量,第一個維度數字用 來索引圖片,第二個維度數字用來索引每張圖片中的畫素點。圖片裡的某個畫素的強度值介於0-1 之間。
  2. MNIST資料集的標籤是介於0-9的數字,我們要把標籤轉化為“one-hot vectors”。一個onehot向量除了某一位數字是1以外,其餘維度數字都是0,比如標籤0將表示為([1,0,0,0,0,0,0,0,0,0]) ,標籤3將表示為([0,0,0,1,0,0,0,0,0,0]) 。
  3. 因此, mnist.train.labels 是一個 [60000, 10] 的數字矩陣。 例如,下面這幅圖,代表的數字為5042在這裡插入圖片描述

softmax函式:

  1. 我們知道MNIST的結果是0-9,我們的模型可能推測出一張圖片是數字9的概率是80%,是數字8 的概率是10%,然後其他數字的概率更小,總體概率加起來等於1。這是一個使用softmax迴歸模型的經典案例。softmax模型可以用來給不同的物件分配概率。 在這裡插入圖片描述

程式如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import  input_data

# 載入資料集
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# 定義批次batch_size,一次性放入100張圖片
batch_size = 100
# 計算一個有多少個批次
n_batch = mnist.train.num_examples // batch_size

# 定義兩個placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# 建立一個簡單的神經網路
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros(10))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)

# 二次代價函式
loss = tf.reduce_mean(tf.square(y - prediction))

# 使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# 初始化變數
init = tf.initialize_all_variables()

# 預測的結果
# tf.argmax()返回最大值所在的列
# 結果存放在一個bool型列表中
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# 求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y: mnist.test.labels})
        print("Iter" + str(epoch) + ",Testing Accuracy" + str(acc))

執行結果如下:

Iter0,Testing Accuracy0.7488
Iter1,Testing Accuracy0.8331
Iter2,Testing Accuracy0.8592
Iter3,Testing Accuracy0.8707
Iter4,Testing Accuracy0.8779
Iter5,Testing Accuracy0.8814
Iter6,Testing Accuracy0.885
Iter7,Testing Accuracy0.8884
Iter8,Testing Accuracy0.8917
Iter9,Testing Accuracy0.8936
Iter10,Testing Accuracy0.8962
Iter11,Testing Accuracy0.8968
Iter12,Testing Accuracy0.8982
Iter13,Testing Accuracy0.8994
Iter14,Testing Accuracy0.9009
Iter15,Testing Accuracy0.9023
Iter16,Testing Accuracy0.9031
Iter17,Testing Accuracy0.9037
Iter18,Testing Accuracy0.9044
Iter19,Testing Accuracy0.9053
Iter20,Testing Accuracy0.9053

準確率大概在90%左右。