1. 程式人生 > >tensorflow SVM 線性可分資料分類

tensorflow SVM 線性可分資料分類

引言

對於SVM,具體的可以參考其他部落格。我覺得SVM裡面的數學知識不好懂,特別是拉格朗日乘子法和後續的 FTT 條件。一個簡單的從整體上先把握的方法就是:不要管怎麼來的,知道後面是一個二次優化就行了。 此處還是用 iris 資料集的萼片長度和花瓣寬度來對鳶尾花分類。 所用的損失函式是 在這裡插入圖片描述 其中 n 是每次訓練的資料量,就是下面程式碼中的 batch_size ,A、b是要優化的變數,A是係數,b是偏差。yi 是理論輸出(-1 或 1)。a(阿爾法,打不出)是權重,自己設,至於怎麼設就根據經驗了。

結果展示

在這裡插入圖片描述

程式碼

# 匯入庫
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
from sklearn import datasets

sess = tf.Session()

# 設定隨機種子,程式碼結束後有討論
np.random.seed(7)
tf.set_random_seed(8)

iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])

# 隨機取90%的資料為訓練集,剩下的為測試集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.9), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test  = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test  = y_vals[test_indices]

batch_size = 100
trian_times = 500
learning_rate = 0.01

# placeholder 就是裝訓練時放資料的容器
x_data = tf.placeholder(dtype=tf.float32, shape=[None, 2])
y_target = tf.placeholder(dtype=tf.float32, shape=[None, 1])

# 變數Variable 是要優化的變數,要是不能變肯定不能優化了
A = tf.Variable(tf.random_normal(shape=[2,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

module_output = tf.subtract(tf.matmul(x_data, A), b)

# 損失函式,此處用的是:看下文的損失函式
l2_norm = tf.reduce_sum(tf.square(A))
alpha = tf.constant([0.01])
classification_term = tf.reduce_mean(tf.maximum(0., tf.subtract(1., tf.multiply(module_output, y_target))))
loss = tf.add(classification_term, tf.multiply(alpha, l2_norm))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# 計算精度
prediction = tf.sign(module_output)
accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, y_target), tf.float32))

init = tf.global_variables_initializer()
sess.run(init)

loss_vec = []
train_accuracy = []
test_accuracy = []

# 訓練在 for 迴圈裡面
for i in range(trian_times):
	index_rand = np.random.choice(len(x_vals_test), size=batch_size)
	x_rand = x_vals_train[index_rand]
	y_rand = np.transpose([y_vals_train[index_rand]])
	sess.run(optimizer, feed_dict={x_data:x_rand, y_target:y_rand})
	loss_vec.append(sess.run(loss, feed_dict={x_data:x_rand, y_target:y_rand}))
	train_accuracy.append(sess.run(accuracy, feed_dict={x_data:x_vals_train, y_target:np.transpose([y_vals_train])}))
	test_accuracy.append(sess.run(accuracy, feed_dict={x_data:x_vals_test, y_target:np.transpose([y_vals_test])}))

[[a1], [a2]] = sess.run(A)
[[bb]] = sess.run(b)

# 準備畫分割線
slope = -a2/a1
intercept = bb/a1
x_line = [x[1] for x in x_vals]
y_line = [slope*x+intercept for x in x_line]

# 準備畫資料點
setosa_x = [d[1] for i,d in enumerate(x_vals) if y_vals[i]==1]
setosa_y = [d[0] for i,d in enumerate(x_vals) if y_vals[i]==1]
not_setosa_x = [d[1] for i,d in enumerate(x_vals) if y_vals[i]==-1]
not_setosa_y = [d[0] for i,d in enumerate(x_vals) if y_vals[i]==-1]

# 畫資料點
plt.subplot(221)
plt.plot(setosa_x, setosa_y, 'ro', label='Is setosa')
plt.plot(not_setosa_x, not_setosa_y, 'g*', label='Non-setosa')
plt.plot(x_line, y_line, 'b-', label='Linear Seperator')
plt.ylim([2, 10])
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.legend(loc='lower right')

# 畫精度變化曲線
plt.subplot(222)
plt.plot(train_accuracy, 'g-', label='Training Accuracy')
plt.plot(test_accuracy, 'r--', label='Test Accuracy')
plt.title('Train and Test Set Accuracies')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

# 畫 Loss 曲線
plt.subplot(223)
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')

plt.show()

討論

上面有一個隨機樹的設定,制定了產生隨機數的種子(seed),如果你執行我的程式碼,雖然中間有產生隨機數的過程,但產生的所有隨機數都我執行時產生的隨機數一樣,最終得到的分割線的斜率等也和我的一模一樣。 但是,讓人迷惑的是,換用一些其他的隨機數種子,即把上面的第 10、11行寫成

np.random.seed(14)
tf.set_random_seed(222)

這是產生的圖如下 在這裡插入圖片描述 一看就知道分類效果很差。 換句話說,分類效果和隨機數有關,要是不初始化隨機數,可成產生很糟的結果。在剛開始我沒有寫第10、11行時,出現過不少完全沒有分為兩類的情況。我也不清楚什麼原因。歡迎留言討論。

參考書籍

Nick McClure. TensorFlow機器學習攻略(影印版)[M]. 東南大學出版社(南京).2017.10