tensorflow 學習專欄(四):使用tensorflow在mnist資料集上使用邏輯迴歸logistic Regression進行分類
阿新 • • 發佈:2019-02-20
在面對分類問題時,我們常用的一個演算法便是邏輯迴歸(logistic Regression)
在本次實驗中,我們的實驗物件是mnist手寫資料集,在該資料集中每張影象包含28*28個畫素點如下圖所示:
我們使用邏輯迴歸演算法來對mnist資料集的資料進行分類,判斷影象所表示的數字是幾。
程式碼如下:
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data tf.set_random_seed(1) np.random.seed(1) BATCH_SIZE = 50 LR = 0.001 mnist = input_data.read_data_sets('./mnist',one_hot=True) #匯入MNIST資料集 test_x = mnist.test.images[:2000] #將MNIST.TEST前2000個數據設定為測試資料集 test_y = mnist.test.labels[:2000] x = tf.placeholder(tf.float32,[None,784])/255. y = tf.placeholder(tf.int32,[None,10]) def addlayer(input,in_size,out_size,activiation_function=None): #定義addlayer函式 Weight = tf.Variable(tf.zeros([in_size,out_size])) Baise = tf.Variable(tf.zeros([out_size])) wx_b = tf.matmul(input,Weight)+Baise if activiation_function is None: out = wx_b else: out = activiation_function(wx_b) return out #build model pred = addlayer(x,784,10,tf.nn.softmax) #構建模型 loss = tf.losses.softmax_cross_entropy(onehot_labels=y,logits=pred) #計算誤差 train = tf.train.AdamOptimizer(LR).minimize(loss) #訓練優化 accuracy = tf.metrics.accuracy(labels=tf.argmax(y,axis=1),predictions=tf.argmax(pred,axis=1),)[1] #計算準確率 sess = tf.Session() #初始化 sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) for step in range(10000): #訓練 b_x,b_y = mnist.train.next_batch(BATCH_SIZE) _,loss_ = sess.run([train,loss],feed_dict={x:b_x,y:b_y}) if step%50==0: accuracy_ = sess.run(accuracy,feed_dict={x:test_x,y:test_y}) print('train loss:%.4f'%loss_, '|test accuracy%.4f'%accuracy_) for i in range(5): #將test資料集前5個數據進行視覺化 X = test_x[i][np.newaxis,:] Y = test_y[i] test_output = sess.run(pred,feed_dict={x:X}) pred_y = np.argmax(test_output,axis=1) real_y = np.argmax(Y) img = X.reshape((28,28)) plt.imshow(img,cmap='gray') plt.text(1.5,2.5,'real number=%.4f'%real_y,fontdict={'size':20,'color':'green'}) plt.text(1.5,5,'pred number=%.4f'%pred_y,fontdict={'size':20,'color':'red'}) plt.show()
訓練結果如下:
測試資料集中前五個資料視覺化結果如下: