1. 程式人生 > >tensorflow100天—第5天:最近鄰演算法

tensorflow100天—第5天:最近鄰演算法

python程式碼

import tensorflow as tf 
import numpy as np 

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)

xtrain, ytrain = mnist.train.next_batch(2000)       # 用5000個樣本做容器物件
xtest, ytest = mnist.test.next_batch(200)           # 200個測試樣本
xi = tf.placeholder(tf.float32, [None, 784]) # 數值可變的變數都要宣告成Variable yi = tf.placeholder(tf.float32,784) distance = tf.reduce_sum(tf.abs(tf.add(xi, tf.negative(yi) )), reduction_indices=1 ) pred = tf.argmin(distance, 0) # 最短的距離 init = tf.global_variables_initializer() acc = 0.0 with tf.Session(
) as sess: sess.run(init) for i in range(len(xtest)): sample = xtest[i,:] pres_index = sess.run(pred, feed_dict={xi:xtrain, yi:sample }) pred_label = np.argmax(ytrain[pres_index]) target = np.argmax(ytest[i]) print('sample:{} | pred:{} | target:{}'.format
(i, pred_label, target)) if pred_label == target: acc += 1/len(xtest) print('finished, final accuracy is:{}'.format(acc))

總結

  1. tensorflow可以只進行前向計算,比如本例,相當於就是numpy