1. 程式人生 > TensorFlow(八) TensorFlow圖像識別(KNN)

TensorFlow(八) TensorFlow圖像識別(KNN)

nump session trac inf dict sha ceil dom 數據

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn import  datasets
import random
from PIL import Image

from tensorflow.examples.tutorials.mnist import  input_data

# k-nearest-neighbour digit classification on MNIST with TensorFlow 1.x.
#
# A random subset of the training split serves as the reference set; a random
# subset of the held-out test split is classified in small batches by L1
# distance to the k closest reference images, using a majority vote over
# their one-hot labels.

sess = tf.Session()
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# This example has 10 classes (one-hot labels of width 10).
train_size = 1000
test_size = 102

# Sample the reference points from the training split and the evaluation
# points from the held-out test split, so that no evaluated image can also
# be one of its own neighbours.
rand_train_indices = np.random.choice(len(mnist.train.images), train_size, replace=False)
rand_test_indices = np.random.choice(len(mnist.test.images), test_size, replace=False)
x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]
y_vals_train = mnist.train.labels[rand_train_indices]
y_vals_test = mnist.test.labels[rand_test_indices]

k = 4
batch_size = 6

x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
# Not used by the graph below; kept (and still fed) for compatibility.
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)

# L1 distance, shape (6, 1000):
# broadcasting (1000, 784) - (6, 1, 784) gives (6, 1000, 784), summed over pixels.
distance = tf.reduce_sum(
    tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))),
    axis=2,
)

# Top k, shape (6, 4): negate so that the *smallest* distances rank highest.
top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)

# Neighbour labels, shape (6, 4, 10) = gather((1000, 10), (6, 4)).
prediction_indices = tf.gather(y_target_train, top_k_indices)

# Per-class vote counts, shape (6, 10).
count_of_prediction = tf.reduce_sum(prediction_indices, axis=1)

# Predicted class per test sample, shape (6,).
prediction = tf.argmax(count_of_prediction, axis=1)

# Divide by a float so the batch count is also correct under Python 2 division.
num_loop = int(np.ceil(len(x_vals_test) / float(batch_size)))

test_output = []
actual_vals = []
for i in range(num_loop):
    min_index = i * batch_size
    max_index = min((i + 1) * batch_size, len(x_vals_test))
    # Fetch the current batch.
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    predictions = sess.run(
        prediction,
        feed_dict={
            x_data_test: x_batch,
            x_data_train: x_vals_train,
            y_target_test: y_batch,
            y_target_train: y_vals_train,
        },
    )
    test_output.extend(predictions)
    actual_vals.extend(np.argmax(y_batch, axis=1))

# No further graph evaluation happens below, so release the session.
sess.close()

# Accuracy: fraction of test samples whose predicted class matches the label.
num_correct = sum(1 for pred, actual in zip(test_output, actual_vals) if pred == actual)
accuracy = float(num_correct) / test_size
print("Accuarcy: " + str(accuracy))

# Plot the last batch with its true and predicted labels.
# The 2 x 3 grid holds exactly one batch of batch_size = 6 images.
actuals = np.argmax(y_batch, axis=1)
for i, actual in enumerate(actuals):
    plt.subplot(2, 3, i + 1)
    plt.imshow(np.reshape(x_batch[i], [28, 28]), cmap="Greys_r")
    plt.title("Actual: " + str(actual) + " Pred: " + str(predictions[i]), fontsize=10)
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)
plt.show()

技術分享圖片

TensorFlow(八) TensorFlow圖像識別(KNN)