1. 程式人生 > >TensorFlow 深度學習框架 (2)-- 反向傳播優化神經網路

TensorFlow 深度學習框架 (2)-- 反向傳播優化神經網路

訓練神經網路的過程就是設定神經網路引數的過程,只有經過有效訓練的神經網路模型才可以真正的解決分類問題或迴歸問題。使用監督學習的方式設定神經網路引數需要有一個標註好的訓練資料集。監督學習的最重要的思想是,在已知答案的標註資料集上,模型給出的預測結果要儘可能接近真實的答案。

在神經網路優化演算法中,最常用的就是反向傳播演算法。反向傳播演算法的具體工作流程如圖


如圖所示是訓練的流程圖,那麼在訓練之前,還有一個問題有待解決,就是如何評估模型與標註資料的差距?答案是通過優化損失函式,損失函式就是代表訓練資料與標註資料差異的一個指標。針對分類問題,其中最常用的損失函式之一就是交叉熵。關於交叉熵放到後續再仔細討論,這裡知道這麼一個概念就可以了。

針對第(1)講述的那個神經網路結構,訓練程式碼示例和註釋 如下

import tensorflow as tf
from numpy.random import RandomState

#以下就是第(1)節中的前向傳播過程
batch_size = 8
w1 = tf.Variable(tf.random_normal([2,3], stddev = 1,seed = 1))
w2 = tf.Variable(tf.random_normal([3,1], stddev = 1,seed = 1))
x = tf.placeholder(tf.float32, shape = (None,2),name = "x-input"))  # shape = (None,2) 代表一個batch 的訓練資料
y_ = tf.placeholder(tf.float32,shape = (None,1),name = "y-input"))  # 一個batch 的標註資料

#前向傳播過程
a = tf.matmul(x,w1)
y = tf.matmul(a,w2)

#定義損失函式(交叉熵)和反向傳播演算法
cross_entropy = -tf.reduce_mean(
    y_ * tf.log(tf.clip_by_value(y,1e-10,1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

#通過隨機數生成一個模擬訓練資料集及其標註
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size,2)
#定義規則來給出樣本的標籤,在這裡所有 x1 + x2 < 1 的樣例被認為是正樣本
Y = [[int(x1 + x2 < 1)] for (x1,x2) in X]

#建立會話來執行 TensorFlow 程式
with tf.Session() as sess:
    init_op = tf.initialize_all_variables()
    sess.run(init_op) #變數初始化
    
    #設定訓練的輪數
    STEPS = 5000
    for i in range(STEPS):
        #每次選取 batch_size 個樣本進行訓練
        start = (i * batch_size) % dataset_size
        end = min(start + batch_size,dataset_size)

        #通過選取的樣本訓練神經網路並更新引數
        sess.run(train_step,feed_dict = {x:X[start:end],y_:Y[start:end]})
        #每 1000 輪觀測總體交叉熵的結果
        if i % 1000 == 0:
            total_cross_entropy = sess.run(cross_entropy,
                                            feed_dict = {x:X,y_:Y})
            print("After %d training steps,cross entropy on all data is %g" %(i,total_cross_entropy))
            """
            在執行的過程中,交叉熵越小表明訓練隨著輪數的增加,訓練的結果與真實的差距越來越小
            """
    
    #最後觀測訓練出來的引數值
    print sess.run(w1)
    print sess.run(w2)