1. 程式人生 > 簡單易學的機器學習演算法——K-近鄰演算法

簡單易學的機器學習演算法——K-近鄰演算法

# coding:UTF-8

import cPickle as pickle
import gzip
import numpy as np

def load_data(data_file):
	"""Load the train and test splits from a gzip-compressed pickle file.

	The file must unpickle to exactly three items, taken in order as the
	training, validation and test splits. The validation split is discarded.

	NOTE: unpickling can execute arbitrary code, so only pass files that
	come from a trusted source.

	Args:
		data_file: path of the gzip-compressed pickle file.

	Returns:
		A 4-tuple made of elements 0 and 1 of the training split followed by
		elements 0 and 1 of the test split (presumably samples and labels;
		confirm against the data file).
	"""
	with gzip.open(data_file, 'rb') as archive:
		splits = pickle.load(archive)
	training, _validation, testing = splits
	return training[0], training[1], testing[0], testing[1]

def cal_distance(x, y):
	"""Return the squared Euclidean distance between two row vectors.

	Assumes ``x`` and ``y`` are ``numpy.matrix`` rows of identical shape
	(1, n), for which ``*`` is the matrix product. No square root is taken.

	Args:
		x: first row vector.
		y: second row vector.

	Returns:
		The single entry of ``(x - y) * (x - y).T``.
	"""
	diff = x - y
	# (1, n) times (n, 1) gives a 1x1 matrix holding the sum of squares.
	product = diff * diff.T
	return product[0, 0]

def get_prediction(train_y, result):
	"""Return the majority label among the selected training samples.

	Args:
		train_y: labels of the training samples, indexable by position.
		result: indices into ``train_y`` of the chosen neighbours.

	Returns:
		The label that occurs most often among ``train_y[i] for i in result``.
		On a tie, the label that appears earliest in ``result`` wins.

	Raises:
		ValueError: if ``result`` is empty.
	"""
	if len(result) == 0:
		raise ValueError("result must contain at least one neighbour index")

	# Count one vote per neighbour.
	votes = {}
	for index in result:
		label = train_y[index]
		votes[label] = votes.get(label, 0) + 1

	# Pick the label with the MOST votes. Walking the neighbours in their
	# given order (instead of the dict) makes tie-breaking deterministic.
	best_label = None
	best_count = 0
	for index in result:
		label = train_y[index]
		if votes[label] > best_count:
			best_label = label
			best_count = votes[label]
	return best_label

def k_nn(train_data, train_y, test_data, k):
	"""Predict a label for every row of ``test_data`` from its nearest training rows.

	For each test row, the distance to every training row is computed with
	``cal_distance``, the training rows are ranked by increasing distance,
	and the first ``k`` indices are handed to ``get_prediction``.

	Args:
		train_data: 2-D training samples, one sample per row.
		train_y: labels of the training samples.
		test_data: 2-D samples to classify, one sample per row.
		k: number of nearest training rows to use per test row.

	Returns:
		A list with one predicted label per row of ``test_data``.
	"""
	n_test = np.shape(test_data)[0]  # number of samples to classify
	n_train = np.shape(train_data)[0]
	train_indices = range(n_train)

	predict = []
	for row in range(n_test):
		sample = test_data[row, :]
		# Distance from this sample to every training row.
		distances = [cal_distance(train_data[j, :], sample) for j in train_indices]
		# Stable sort: equal distances keep the lower training index first.
		ranked = sorted(train_indices, key=distances.__getitem__)
		# Keep only the k closest training indices.
		nearest = [idx for position, idx in enumerate(ranked) if position < k]
		predict.append(get_prediction(train_y, nearest))
	return predict

def get_correct_rate(result, test_y):
	"""Return the fraction of predictions that equal the true labels.

	Only the first ``len(result)`` entries of ``test_y`` are compared, so
	``test_y`` may be longer than ``result``.

	Args:
		result: predicted labels.
		test_y: true labels, at least as long as ``result``.

	Returns:
		A float between 0.0 and 1.0; 0.0 when ``result`` is empty.

	Raises:
		IndexError: if ``test_y`` has fewer entries than ``result``.
	"""
	total = len(result)
	if total == 0:
		# Nothing was predicted; avoid dividing by zero.
		return 0.0
	hits = sum(1 for i, predicted in enumerate(result) if predicted == test_y[i])
	# float() keeps true division on Python 2 as well.
	return float(hits) / total

if __name__ == "__main__":
	# Step 1: read the training and test splits from disk.
	print("---------- 1、load data ------------")
	train_x, train_y, test_x, test_y = load_data("mnist.pkl.gz")
	# Step 2: run k-NN. The samples are wrapped as matrices because
	# cal_distance relies on matrix multiplication.
	train_x, test_x = np.asmatrix(train_x), np.asmatrix(test_x)
	print("---------- 2、K-NN -------------")
	# Classify only the first 10 test rows, using 10 neighbours each.
	result = k_nn(train_x, train_y, test_x[:10, :], 10)
	print(result)
	# Step 3: report the fraction of correct predictions.
	print("---------- 3、correct rate -------------")
	print(get_correct_rate(result, test_y))