1. 程式人生 > >TensorFlow 全連線網路實現

TensorFlow 全連線網路實現

1**神經網路**是一種數學模型,大量的神經元相連線並進行計算,用來對輸入和輸出間複雜的關係進行建模。

神經網路訓練,通過大量資料樣本,對比正確答案和模型輸出之間的區別(梯度),然後把這個區別(梯度)反向的傳遞回去,對每個相應的神經元進行一點點的改變。那麼下一次在訓練的時候就可以用已經改進一點點的神經元去得到稍微準確一點的結果。

基於TensorFlow實現一個簡單的神經網路。

結構圖

FullConnectedNetwork1.png

搭建神經網路圖

1. 準備訓練資料

匯入相應的包:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as
plt

準備訓練資料:

x_data = np.linspace(-1, 1, 300, dtype=np.float32)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
y_data = 2 * np.power(x_data, 3) + np.power(x_data, 2) + noise

2. 定義網路結構

定義佔位符:

xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None
, 1])

定義神經層:隱藏層和預測層

# 隱層1
Weights1 = tf.Variable(tf.random_normal([1, 5]))
biases1 = tf.Variable(tf.zeros([1, 5]) + 0.1)
Wx_plus_b1 = tf.matmul(xs, Weights1) + biases1
l1 = tf.nn.relu(Wx_plus_b1)
# 隱層2
Weights2 = tf.Variable(tf.random_normal([5, 10]))
biases2 = tf.Variable(tf.zeros([1, 10]) + 0.1)
Wx_plus_b2 = tf.matmul(l1, Weights2) + biases2
l2 = tf.nn.relu(Wx_plus_b2)
# 輸出層
Weights3 = tf.Variable(tf.random_normal([10, 1])) biases3 = tf.Variable(tf.zeros([1, 1]) + 0.1) prediction = tf.matmul(l2, Weights3) + biases3

3. 定義 loss 表示式

這裡採用均方差(mean squared error):

MAE(y,y^)=1nsamplesi=1n(yiyi^)2
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))

4. optimizer

即訓練的優化策略,一般有梯度下降(GradientDescentOptimizer)、AdamOptimizer等。.minimize(loss)是讓 loss 達到最小。

train_step = tf.train.AdamOptimizer(0.1).minimize(loss)

訓練

# 初始化所有變數
init = tf.global_variables_initializer()
# 啟用會話
with tf.Session() as sess:
    sess.run(init)
    # 繪製原始x-y散點圖。
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(x_data, y_data)
    plt.ion()
    plt.show()
    # 迭代次數 = 10000
    for i in range(10000):
        # 訓練
        sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
        # 每50步繪圖並列印輸出。
        if i % 50 == 0:
            # 視覺化模型輸出的結果。
            try:
                ax.lines.remove(lines[0])
            except Exception:
                pass
            prediction_value = sess.run(prediction, feed_dict={xs: x_data})
            # 繪製模型預測值。
            lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
            plt.pause(1)
            # 列印損失
            print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))

最終結果

最終損失:0.0026713(不同的初始化可能會有不同)

FullConnectedNetwork2.png

Reference