1. 程式人生 > >TensorFlow HOWTO 1.3 邏輯迴歸

TensorFlow HOWTO 1.3 邏輯迴歸

1.3 邏輯迴歸

將線性迴歸的模型改一改,就可以用於二分類。邏輯迴歸擬合樣本屬於某個分類,也就是樣本為正樣本的概率。

操作步驟

匯入所需的包。

import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import sklearn.model_selection as ms

匯入資料,並進行預處理。我們使用鳶尾花資料集所有樣本,根據萼片長度和花瓣長度預測樣本是不是山鳶尾(第一種)。

iris = ds.load_iris()

x_ = iris.data[:, [0, 2]]
y_ = (iris.target == 0).astype(int)
y_ = np.expand_dims(y_ , 1)

x_train, x_test, y_train, y_test = \
    ms.train_test_split(x_, y_, train_size=0.7, test_size=0.3)

定義超引數。

變數 含義
n_input 樣本特徵數
n_epoch
迭代數
lr 學習率
threshold 如果輸出超過這個概率,將樣本判定為正樣本
n_input = 2
n_epoch = 2000
lr = 0.05
threshold = 0.5

搭建模型。

變數 含義
x 輸入
y 真實標籤
w 權重
b 偏置
z
中間變數,x的線性變換
a 輸出,也就是樣本是正樣本的概率
x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.float64, [None, 1])
w = tf.Variable(np.random.rand(n_input, 1))
b = tf.Variable(np.random.rand(1, 1))
z = x @ w + b
a = tf.sigmoid(z)

定義損失、優化操作、和準確率度量指標。分類問題有很多指標,這裡只展示一種。

我們使用交叉熵損失函式,如下。

m e a n ( Y log ( A ) + ( 1 Y ) log ( 1 A ) ) -mean(Y \otimes \log(A) + (1-Y) \otimes \log(1-A))

它的意思是,對於正樣本,y 為 1,損失變為-log(a),輸出會盡可能接近一。對於負樣本,y為 0,損失變為-log(1 - a),輸出會盡可能接近零。總之,它使輸出儘可能接近真實標籤。

變數 含義
loss 損失
op 優化操作
y_hat 標籤的預測值
acc 準確率
loss = - tf.reduce_mean(y * tf.log(a) + (1 - y) * tf.log(1 - a))
op = tf.train.AdamOptimizer(lr).minimize(loss)

y_hat = tf.to_double(a > threshold)
acc = tf.reduce_mean(tf.to_double(tf.equal(y_hat, y)))

使用訓練集訓練模型。

losses = []
accs = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)
    
    for e in range(n_epoch):
        _, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})
        losses.append(loss_)

使用測試集計算準確率。

        acc_ = sess.run(acc, feed_dict={x: x_test, y: y_test})
        accs.append(acc_)

每一百步列印損失和度量值。

        if e % 100 == 0:
            print(f'epoch: {e}, loss: {loss_}, acc: {acc_}')
            saver.save(sess,'logit/logit', global_step=e)

得到決策邊界:

    x_plt = x_[:, 0]
    y_plt = x_[:, 1]
    c_plt = y_.ravel()
    x_min = x_plt.min() - 1
    x_max = x_plt.max() + 1
    y_min = y_plt.min() - 1
    y_max = y_plt.max() + 1
    x_rng = np.arange(x_min, x_max, 0.05)
    y_rng = np.arange(y_min, y_max, 0.05)
    x_rng, y_rng = np.meshgrid(x_rng, y_rng)
    model_input = np.asarray([x_rng.ravel(), y_rng.ravel()]).T
    model_output = sess.run(y_hat, feed_dict={x: model_input}).astype(int)
    c_rng = model_output.reshape(x_rng.shape)

輸出:

epoch: 0, loss: 3.935746371309244, acc: 0.3333333333333333
epoch: 100, loss: 0.1969325408656252, acc: 1.0
epoch: 200, loss: 0.08548362243852041, acc: 1.0
epoch: 300, loss: 0.050833687966014396, acc: 1.0
epoch: 400, loss: 0.034929315249291375, acc: 1.0
epoch: 500, loss: 0.026013692651528184, acc: 1.0
epoch: 600, loss: 0.02038864243607467, acc: 1.0
epoch: 700, loss: 0.016552042129938136, acc: 1.0
epoch: 800, loss: 0.013786692432697542, acc: 1.0
epoch: 900, loss: 0.011709709551073783, acc: 1.0
epoch: 1000, loss: 0.010099234422592073, acc: 1.0
epoch: 1100, loss: 0.008818382202721829, acc: 1.0
epoch: 1200, loss: 0.007778392815694136, acc: 1.0
epoch: 1300, loss: 0.0069193419951217704, acc: 1.0
epoch: 1400, loss: 0.0061993983430654875, acc: 1.0
epoch: 1500, loss: 0.00558852696047961, acc: 1.0
epoch: 1600, loss: 0.005064638072189167, acc: 1.0
epoch: 1700, loss: 0.00461114435393481, acc: 1.0
epoch: 1800, loss: 0.004215362417896155, acc: 1.0
epoch: 1900, loss: 0.003867437954560204, acc: 1.0

繪製整個資料集以及決策邊界。

plt.figure()
cmap = mpl.colors.ListedColormap(['r', 'b'])
plt.scatter(x_plt, y_plt, c=c_plt, cmap=cmap)
plt.contourf(x_rng, y_rng, c_rng, alpha=0.2, linewidth=5, cmap=cmap)
plt.title('Data and Model')
plt.xlabel('Petal Length (cm)')
plt.ylabel('Sepal Length (cm)')
plt.show()

繪製訓練集上的損失。

plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('Cross Entropy')
plt.show()

繪製測試集上的準確率。

plt.figure()
plt.plot(accs)
plt.title('Accurary on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('Accurary')
plt.show()

擴充套件閱讀