1. 程式人生 > >tensorflow實現logistic迴歸進行手寫字識別

tensorflow實現logistic迴歸進行手寫字識別

1.資料準備

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

#載入mnist資料

mnist=input_data.read_data_sets('/data/machine_learning/mnist/',one_hot=True) #使用one-hot編碼

2.引數設定

#基本引數設定

learning_rate=0.01

training_epochs=25

batch_size=100

display_step=1

#設定輸出輸出資料佔位符

x=tf.placeholder(tf.float32,[None,784])#28*28=784

y=tf.placeholder(tf.float32,[None,10])#是個類別

#設定模型權重

W=tf.Variable(tf.zeros([784,10]))

b=tf.Variable(tf.zeros([10]))

3.模型構建

#構建模型

pred=tf.nn.softmax(tf.matmul(x,W)+b)

#使用交叉熵損失函式

cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

#梯度下降最小化損失

optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#初始化全域性變數

init=tf.global_variables_initializer()

4.模型訓練與測試

#開始訓練

with tf.Session() as sess:

    sess.run(init)

    for epoch in range(training_epochs):

        avg_cost=0

        total_batch=int(mnist.train.num_examples/batch_size)

        for i in range(total_batch):

            batch_xs,batch_ys=mnist.train.next_batch(batch_size)

            _,c=sess.run([optimizer,cost],feed_dict={x:batch_xs,y:batch_ys})

            avg_cost+=c/total_batch

        if(epoch+1)%display_step==0:

            print("Epoch:",'%04d' %(epoch+1),'cost=',"{:.9f}".format(avg_cost))

    print("優化完成!")

    #模型測試

    correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

    #計算精度

    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    print("精度:",accuracy.eval({x:mnist.test.images,y:mnist.test.labels}))