1. 程式人生 > >tensorflow實現多層感知機進行手寫字識別

tensorflow實現多層感知機進行手寫字識別

logits=multilayer_perceptron(X)

#使用交叉熵損失

loss_op=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=Y))

#定義優化器

optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate)

train_op=optimizer.minimize(loss_op)#使用優化器最小化損失

init=tf.global_variables_initializer()#變數初始化器

with tf.Session()as sess:

    sess.run(init)

    for epoch in range(training_epochs):

        avg_cost=0

        total_batch=int(mnist.train.num_examples/batch_size)

        for i in range(total_batch):

            batch_x,batch_y=mnist.train.next_batch(batch_size)

            _,c=sess.run([train_op,loss_op],feed_dict={X:batch_x,Y:batch_y})

            avg_cost+=c/total_batch #計算平均損失

        if epoch%display_step==0:

            print("Epoch:","%04d"%(epoch+1),"cost={:.9f}".format(avg_cost))

    print("優化完成!")

    pred=tf.nn.softmax(logits=logits)

    correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))

    accuracy=tf.reduce_mean(tf.cast(correct_prediction,"float"))

    print("Accuracy:",accuracy.eval({X:mnist.test.images,Y:mnist.test.labels})) #使用測試集來驗證模型精度