
Image Recognition with TensorFlow Using the Nearest Neighbor Method

1. Import the libraries

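# Written against the TensorFlow 1.x API: tf.Session, tf.placeholder and the
# tensorflow.examples.tutorials.mnist helper are not available in TensorFlow 2.x
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data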

2. Create a session and load the dataset

sess = tf.Session()
# Raw string so the backslashes in the Windows path are not treated as escape sequences
mnist = input_data.read_data_sets(r"E:\Python Project\mnist\MNIST_data", one_hot=True)

3. Split the dataset (randomly sample small training and test subsets)

# Nearest neighbor compares every test image with every training image,
# so keep both subsets small
train_size = 100
test_size = 102
rand_train_indices = np.random.choice(len(mnist.train.images), train_size, replace=False)
rand_test_indices = np.random.choice(len(mnist.test.images), test_size, replace=False)
x_vals_train = mnist.train.images[rand_train_indices]
x_vals_test = mnist.test.images[rand_test_indices]
y_vals_train = mnist.train.labels[rand_train_indices]
# Test labels must come from the test set, matching rand_test_indices
y_vals_test = mnist.test.labels[rand_test_indices]
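`np.random.choice(n, size, replace=False)` returns `size` distinct indices, so no image appears twice within a subset. The stdlib-only sketch below mimics that behavior with `random.Random.sample`; the helper name `sample_split`, the seeds and the pool sizes are illustrative assumptions, not values taken from the original code.

```python
import random


def sample_split(pool_size, sample_size, seed=None):
    """Return sample_size distinct indices drawn from range(pool_size).

    Like np.random.choice(pool_size, sample_size, replace=False), each
    index appears at most once. sample_split is a hypothetical helper.
    """
    rng = random.Random(seed)
    return rng.sample(range(pool_size), sample_size)


# Illustrative pool sizes; train and test indices come from separate pools,
# so labels must be looked up in the same split the indices were drawn from.
train_idx = sample_split(55000, 100, seed=0)
test_idx = sample_split(10000, 102, seed=1)
print(len(train_idx), len(set(train_idx)))  # 100 100
print(len(test_idx), len(set(test_idx)))    # 102 102
```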

4. Declare k, the batch size and the placeholders

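k = 4           # number of nearest neighbors that vote
batch_size = 6  # test images classified per sess.run call
# Each MNIST image is a flattened 28x28 = 784-pixel vector; labels are one-hot over 10 digits
x_data_train = tf.placeholder(shape=[None, 784], dtype=tf.float32)
x_data_test = tf.placeholder(shape=[None, 784], dtype=tf.float32)
y_target_train = tf.placeholder(shape=[None, 10], dtype=tf.float32)
y_target_test = tf.placeholder(shape=[None, 10], dtype=tf.float32)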
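Because the dataset is loaded with `one_hot=True`, each label is a length-10 vector with a single 1.0 at the digit's position, which is why the label placeholders have shape `[None, 10]`; `np.argmax(..., axis=1)` later recovers the digit. A small pure-Python sketch of both directions (the helpers `one_hot` and `from_one_hot` are hypothetical names used only for illustration):

```python
def one_hot(label, num_classes=10):
    """Encode an integer digit as a one-hot list of floats."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def from_one_hot(vec):
    """Recover the digit: the position of the largest entry (like argmax)."""
    return vec.index(max(vec))


encoded = one_hot(3)
print(encoded)                # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(from_one_hot(encoded))  # 3
```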

5. Define the distance metric

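# L1 (Manhattan) distance: the test batch [B, 784] is expanded to [B, 1, 784] and
# broadcast against the training set [N, 784], giving a [B, N] distance matrix
distance = tf.reduce_sum(tf.abs(tf.subtract(x_data_train, tf.expand_dims(x_data_test, 1))), axis=2)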
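The distance line relies on broadcasting: `tf.expand_dims(x_data_test, 1)` turns the `[B, 784]` test batch into `[B, 1, 784]`, subtracting it from the `[N, 784]` training matrix yields `[B, N, 784]`, and summing the absolute differences over the last axis leaves a `[B, N]` matrix of L1 distances. The sketch below computes the same matrix with plain loops on toy 3-pixel vectors; the function name and the data are illustrative assumptions.

```python
def l1_distance_matrix(test_rows, train_rows):
    """Return a len(test_rows) x len(train_rows) matrix of L1 distances.

    Entry [i][j] is sum(|test_rows[i][d] - train_rows[j][d]|) over all
    pixels d, the same quantity the TensorFlow graph computes with
    broadcasting and reduce_sum over the last axis.
    """
    return [[sum(abs(a - b) for a, b in zip(t, r)) for r in train_rows]
            for t in test_rows]


train = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]  # N = 3 toy images
test = [[0.0, 0.5, 0.0], [1.0, 1.0, 0.0]]                    # B = 2 toy images
dist = l1_distance_matrix(test, train)
print(dist)  # [[0.5, 2.5, 0.5], [2.0, 1.0, 1.0]]
```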

6. Find the k closest images and build the prediction

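# top_k returns the largest values, so negate the distances to get the k nearest neighbors
top_k_xvals, top_k_indices = tf.nn.top_k(tf.negative(distance), k=k)
# One-hot labels of the k neighbors: [B, k, 10]
prediction_indices = tf.gather(y_target_train, top_k_indices)
# Votes per class: [B, 10]
count_of_predictions = tf.reduce_sum(prediction_indices, axis=1)
# Predicted digit = class with the most votes: [B]
prediction = tf.argmax(count_of_predictions, axis=1)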
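For one test image, the graph above takes the k smallest distances, gathers the one-hot labels of those neighbors, sums them into per-class vote counts and picks the class with the most votes. The pure-Python sketch below follows the same steps; `knn_predict` and the toy data are illustrative assumptions. It breaks vote ties in favor of the smallest class index, whereas the TensorFlow documentation for `tf.argmax` does not guarantee which index is returned on ties.

```python
def knn_predict(distances, train_labels_one_hot, k):
    """Predict a class for one test image by majority vote of its k nearest neighbors.

    distances: L1 distance from the test image to each training image.
    train_labels_one_hot: one-hot label vector for each training image.
    """
    # Indices of the k smallest distances (same as top_k of the negated distances).
    nearest = sorted(range(len(distances)), key=lambda j: distances[j])[:k]
    # Sum the neighbors' one-hot labels to get a vote count per class.
    num_classes = len(train_labels_one_hot[0])
    votes = [0] * num_classes
    for j in nearest:
        for c in range(num_classes):
            votes[c] += train_labels_one_hot[j][c]
    # Class with the most votes; ties go to the smallest class index here.
    return votes.index(max(votes))


labels = [[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
distances = [4.0, 1.0, 2.0, 0.5, 9.0]
print(knn_predict(distances, labels, 3))  # 1 (neighbors 3, 1, 2 vote 2-for-1 for class 1)
print(knn_predict(distances, labels, 1))  # 2 (only neighbor 3)
```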

7. Iterate over the test batches, compute the predictions and store the results

test_output = []
actual_vals = []
# Number of batches needed to cover the whole test set
num_loops = int(np.ceil(len(x_vals_test) / batch_size))

for i in range(num_loops):
    min_index = i * batch_size
    # Bound the slice by the size of the test set (not the training set)
    max_index = min((i + 1) * batch_size, len(x_vals_test))
    x_batch = x_vals_test[min_index:max_index]
    y_batch = y_vals_test[min_index:max_index]
    predictions = sess.run(prediction, feed_dict={x_data_train: x_vals_train, x_data_test: x_batch,
                                                  y_target_train: y_vals_train, y_target_test: y_batch})
    test_output.extend(predictions)
    actual_vals.extend(np.argmax(y_batch, axis=1))
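Batching the test set takes ceil(n / batch_size) iterations, and the end of the last slice must be capped at the test set length. With 102 test images and a batch size of 6 that is exactly 17 batches; capping at the training set length (100) instead would silently drop the last two test images. A minimal sketch of the slice bounds (the generator name `batch_bounds` is hypothetical):

```python
import math


def batch_bounds(n_items, batch_size):
    """Yield (start, end) slice bounds that cover range(n_items) in batches."""
    num_loops = math.ceil(n_items / batch_size)
    for i in range(num_loops):
        start = i * batch_size
        end = min((i + 1) * batch_size, n_items)  # last batch may be shorter
        yield start, end


bounds = list(batch_bounds(102, 6))
print(len(bounds), bounds[-1])    # 17 (96, 102)
print(list(batch_bounds(10, 4)))  # [(0, 4), (4, 8), (8, 10)]
```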
    

8. Compute the accuracy

accuracy = sum([1. / test_size for i in range(test_size) if test_output[i] == actual_vals[i]])
print('Accuracy on test set: ' + str(accuracy))
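Accuracy is the fraction of test images whose predicted digit equals the true digit. Counting the matches and dividing once gives the same value as adding `1./test_size` per match, up to floating-point rounding, and makes a length mismatch between the two lists easy to catch. A minimal sketch with made-up labels (the function name is illustrative):

```python
def accuracy(predicted, actual):
    """Fraction of positions where the predicted and actual labels agree."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    correct = sum(1 for p, a in zip(predicted, actual) if p == a)
    return correct / len(actual)


print(accuracy([7, 2, 1, 0], [7, 2, 4, 0]))  # 0.75
```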

9. Plot the results of the last batch

actuals = np.argmax(y_batch, axis=1)

# 2 x 3 grid = 6 subplots, one per image in the last batch (batch_size = 6)
Nrows = 2
Ncols = 3
for i in range(len(actuals)):
    plt.subplot(Nrows, Ncols, i + 1)
    # Reshape the flat 784-pixel vector back into a 28x28 image
    plt.imshow(np.reshape(x_batch[i], [28, 28]), cmap='Greys_r')
    plt.title('Actual: ' + str(actuals[i]) + ' Pred: ' + str(predictions[i]), fontsize=10)
    # Hide the axis ticks
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)

plt.show()
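`np.reshape(x_batch[i], [28, 28])` works because each MNIST image is stored as a flat 784-pixel vector, and NumPy's default row-major (C) order places pixels 0 to 27 in the first row, 28 to 55 in the second, and so on. A pure-Python sketch of the same reshape (the helper `to_rows` is hypothetical):

```python
def to_rows(flat, width=28):
    """Split a flat pixel list into rows of `width` pixels, in row-major order."""
    if len(flat) % width != 0:
        raise ValueError("length must be a multiple of width")
    return [flat[r:r + width] for r in range(0, len(flat), width)]


pixels = list(range(784))  # stand-in for one flattened image
image = to_rows(pixels)
print(len(image), len(image[0]))  # 28 28
print(image[1][0])                # 28 (first pixel of the second row)
```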

10. Results