TensorFlow實戰之Softmax Regression識別手寫數字
關於本文說明,本人原博客地址位於http://blog.csdn.net/qq_37608890,本文來自筆者於2018年02月21日 23:10:04所撰寫內容(http://blog.csdn.net/qq_37608890/article/details/79343860)。
本文根據最近學習TensorFlow書籍網絡文章的情況,特將一些學習心得做了總結,詳情如下.如有不當之處,請各位大拿多多指點,在此謝過。
一、相關概念
1、MNIST
MNIST(Mixed National Institute of Standards and Technology database),作為一個常見的數據集
-
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本)
-
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個標簽)
-
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本)
-
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標簽)
每一個訓練元素都是28*28像素的手寫數字圖片,只有灰度值信息,空白部分為0,筆跡根據顏色深淺取[0, 1], 784維,丟棄二維空間信息,目標分0~9共10類。
2、One-Hot編碼
在我們機器學習應用任務的實現過程中,針對有些非連續的數據,我們也會考慮使用數字來進行編碼。例如“女人”編碼為1,“男人”編碼為2,即便如此,二者在數學上不存在連續關系,但是在機器學習算法中,會認為“女人”和“男人”之間存在著數學上的有序關系。
One-Hot編碼:獨熱編碼,又被稱為一位有效編碼,其方法是使用N位狀態寄存器來對N個狀態進行編碼,任意一個狀態都有它獨立的寄存器位,並且在任意時候只有一位有效。例如上文中說的“女人”和“男人”共有兩種狀態,那麽就可以編碼為01和10,對於有N個狀態的特征,經過one-hot編碼後就會變成N個二元值,而其中只有一個為1。
主要優點如下:
-
解決了分類器不好處理屬性數據的問題;
-
在一定程度上也起到了擴充特征的作用;
3、Softmax回歸
在 logistic 回歸中,我們的訓練集由 m 個已標記的樣本構成: ,其中輸入特征。(我們對符號的約定如下:特征向量 的維度為,其中 對應截距項 。) 由於 logistic 回歸是針對二分類問題的,因此類標記。假設函數(hypothesis function) 如下:
將訓練模型參數 \textstyle \theta,使其能夠最小化代價函數 :
在 softmax回歸中,我們解決的是多分類問題(相對於 logistic 回歸解決的二分類問題),類標 可以取 個不同的值(而不是 2 個)。因此,對於訓練集 ,我們有 。(註意此處的類別下標從 1 開始,而不是 0)。例如,在 MNIST 數字識別任務中,我們有 個不同的類別。
對於給定的測試輸入,我們想用假設函數針對每一個類別j估算出概率值 。也就是說,我們想估計 的每一種分類結果出現的概率。因此,我們的假設函數將要輸出一個 維的向量(向量元素的和為1)來表示這 個估計的概率值。 具體地說,我們的假設函數 形式如下:
其中 是模型的參數。請註意 這一項對概率分布進行歸一化,使得所有概率之和為 1 。
為了方便起見,我們同樣使用符號 來表示全部的模型參數。在實現Softmax回歸時,將 用一個的矩陣來表示會很方便,該矩陣是將 按行羅列起來得到的,如下所示:
二、案例一Softmax回歸實現
1、簡要概述
截止目前,我們已經知道了Logistic函數只能被使用在二分類問題中,但是它的多項式回歸,即softmax函數,可以解決多分類問題。假設softmax函數?的輸入數據是C維度的向量z,那麽softmax函數的數據也是一個C維度的向量y,裏面的值是0到1之間。softmax函數其實就是一個歸一化的指數函數,定義如下:
式子中的分母充當了正則項的作用,可以使得
作為神經網絡的輸出層,softmax函數中的值可以用C個神經元來表示。
對於給定的輸入z,我們可以得到每個分類的概率t = c for c = 1 ... C可以表示為:
其中,P(t=c|z)表示,在給定輸入z時,該輸入數據是c分類的概率。
下圖展示了在一個二分類(t = 1, t = 2)中,輸入向量是z = [z1, z2],那麽輸出概率P(t=1|z)如下圖所示。
2、代碼實現過程如下
#Softmax分類函數及其應用代碼實現 import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import colorConverter,ListedColormap from mpl_toolkits.mplot3d import Axes3D from matplotlib import cm %matplotlib inline #定義Softmax函數 def softmax(z): return np.exp(z)/np.sum(np.exp(z)) #展示在一個二分類(t=1,t=2)中,輸入向量是z=[z1,z2], #那麽輸出概率為P(t=1|Z)的情況。 nb_of_zs = 200 zs = np.linspace(-10,10,num=nb_of_zs) zs_1, zs_2 = np.meshgrid(zs, zs) y = np.zeros((nb_of_zs,nb_of_zs,2)) for i in range(nb_of_zs): for j in range(nb_of_zs): y[i,j,:] = softmax(np.asarray([zs_1[i,j],zs_2[i,j]])) fig = plt.figure() ax = fig.gca(projection=‘3d‘) surf = ax.plot_surface(zs_1,zs_2,y[:,:,0],linewidth =0, cmap=cm.coolwarm) ax.view_init(elev=30,azim=70) cbar = fig.colorbar(surf) ax.set_xlabel(‘$z_1$‘, fontsize=15) ax.set_ylabel(‘$z_2$‘, fontsize=15) ax.set_zlabel(‘$z_1$‘, fontsize=15) ax.set_title(‘$P(t=1|\mathbf{z})$‘) cbar.ax.set_ylabel(‘$P(t=1|\mathbf{z})$‘, fontsize=15) plt.grid() plt.show()
最終生成圖像如下:
3、Softmax回歸模型參數化的特點
Softmax 回歸有一個不尋常的特點:它有一個“冗余”的參數集。為了便於闡述這一特點,假設我們從參數向量 中減去了向量 ,這時,每一個 都變成了 ()。此時假設函數變成了以下的式子:
換句話說,從 中減去完全不影響假設函數的預測結果!這表明前面的 softmax 回歸模型中存在冗余的參數。更正式一點來說, Softmax 模型被過度參數化了。對於任意一個用於擬合數據的假設函數,可以求出多組參數值,這些參數得到的是完全相同的假設函數 。
進一步而言,如果參數 是代價函數 的極小值點,那麽 同樣也是它的極小值點,其中 可以為任意向量。因此使 最小化的解不是唯一的。(有趣的是,由於 仍然是一個凸函數,因此梯度下降時不會遇到局部最優解的問題。但是 Hessian 矩陣是奇異的/不可逆的,這會直接導致采用牛頓法優化就遇到數值計算的問題)。
註意,當 時,我們總是可以將 替換為(即替換為全零向量),並且這種變換不會影響假設函數。因此我們可以去掉參數向量 (或者其他 中的任意一個)而不影響假設函數的表達能力。實際上,與其優化全部的個參數 (其中 ),我們可以令 ,只優化剩余的 個參數,這樣算法依然能夠正常工作。
在實際應用中,為了使算法實現更簡單清楚,往往保留所有參數 ,而不任意地將某一參數設置為 0。但此時我們需要對代價函數做一個改動:加入權重衰減。權重衰減可以解決 softmax 回歸的參數冗余所帶來的數值問題。
三、TensorFlow實現Softmax Regression識別手寫數字
1、項目背景
MNIST(Mixed National Institute of Standards and Technology database),簡單機器視覺數據集,由幾萬張28X28像素的手寫數字組成,這些圖片只包含灰度值信息,空白部分為0,筆跡根據顏色深淺取[0, 1], 784維,我們的目標是對這些手寫數字的圖片進行分類,轉化成0~9共10類。
2、MNIST手寫數字圖片示例圖
3、算法結構特點
-
使用Softmax Regression分類模型進行分類。
-
只有輸入層和輸出層,沒有隱含層。
4、TensorFlow 實現簡單機器算法步驟
-
定義算法公式,神經網絡forward計算。
-
定義loss,選定優化器,指定優化器優化loss。
-
叠代訓練數據。
-
測試集、驗證集評測準確率。
5、實現過程
Softmax函數
計算過程可視化如下
具體代碼實現如下
#調用相關數據
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) #展示訓練集、測試集、驗證集樣本 print(mnist.train.images.shape, mnist.train.labels.shape) print(mnist.test.images.shape, mnist.test.labels.shape) print(mnist.validation.images.shape, mnist.validation.labels.shape) #圖像展示 import numpy as np import matplotlib.pyplot as plt #imshow data imgTol = mnist.train.images img = np.reshape(imgTol[1,:],[28,28]) plt.show()
圖像如下
繼續執行後續代碼,查看Softmax Regression模型的效果情況
import tensorflow as tf sess = tf.InteractiveSession() x=tf.placeholder(tf.float32, [None,784]) W =tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W)+b) y_ =tf.placeholder(tf.float32, [None, 10]) cross_entropy =tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), reduction_indices=[1])) train_step =tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) tf.global_variables_initializer().run() for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) train_step.run({x: batch_xs, y_:batch_ys}) correct_prediction =tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
關於執行準確率情況,筆者測試了7次,結果不盡相同,基本都是0.92左右。
第一次執行結果:0.9216;第二次三次執行結果:0.9171;第四次執行結果:0.9216;第五次執行結果:0.9193;第六次:0.9219;第七次:0.9165。
四、小結
本文涉及TensorFlow實現了一個簡單的機器學習算法Softmax Regression,是一個沒有隱含層的最淺的神經網絡,整個流程在第三部分也提到,這裏再次羅列出來,如下:
- 定義算法公式,神經網絡forward計算。
- 定義loss,選定優化器,指定優化器優化loss。
- 叠代訓練數據。
- 測試集、驗證集評測準確率。
這四部分是使用TensorFlow進行算法設計、訓練的核心流程,會貫穿神經網絡的各類應用。需要提醒的是,我們定義的各個公式其實只是Computation Graph,在執行該行代碼時,計算還沒有實際發生,只有等調用run方法,並feed數據時計算才真正執行。例如cross_entropy、trian_step、accuracy等都是計算圖中的節點,而並不是數據結果,可以通過調用run方法執行這些節點或者講運算操作來獲取結果。
至於第三部分Softmax Regression達到的效果,92%的準確率還不錯,但還達不到實用的程度。手寫數字的識別主要應用在銀行等金融領域,如果準確率不夠高,引起的後果將會非常嚴重。後續文章中,會從感知機、卷積神經網絡的角度解決MNIST手寫數字識別問題。
關於使用TensorFlow來實現Softmax Regression識別手寫數字的撰寫,暫時先到此。
主要參考資料《TensorFlow實戰》(黃文堅 唐源 著)(電子工業出版社)
TensorFlow實戰之Softmax Regression識別手寫數字