TensorFlow技術解析與實戰 8 第一個tensorflow程式
TensorFlow的執行方式分如下4步:
(1)載入資料及定義超引數
(2)構建網路
(3)訓練模型
(4)評估模型和進行預測
8.2 超引數的設定# -*- coding: utf-8 -*- import sys reload(sys) sys.setdefaultencoding('utf-8') import tensorflow as tf import numpy as np # y = x^2 - 0.5 # 生成及載入資料 x_data = np.linspace(-1, 1, 300)[:, np.newaxis] #構建了300個點 noise = np.random.normal(0, 0.05, x_data.shape) #加入一些噪聲點 y_data = np.square(x_data) - 0.5 + noise xs = tf.placeholder(tf.float32, [None, 1]) ys = tf.placeholder(tf.float32, [None, 1]) # 構建網路模型 # y = weights*x + biases def add_layer(inputs, in_size, out_size, activation_function=None): weights = tf.Variable(tf.random_normal([in_size, out_size])) # in_size * out_size 大小的矩陣 biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) # 1 X out_size 的矩陣 Wx_plus_b = tf.matmul(inputs, weights) + biases # 矩陣相乘 if activation_function is None: outputs = Wx_plus_b else: outputs = activation_function(Wx_plus_b) return outputs # 構建隱藏層,假設隱藏層有10個神經元 h1 = add_layer(xs, 1, 20, activation_function=tf.nn.relu) # 構建輸出層,假設輸出層和輸入層一樣,有1個神經元 prediction = add_layer(h1, 20, 1, activation_function=None) # 計算預測值和真實值間的誤差 loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 訓練模型 init = tf.global_variables_initializer() #初始化所有變數 sess = tf.Session() sess.run(init) for i in range(1000): sess.run(train_step, feed_dict={xs:x_data, ys:y_data}) if i % 50 == 0: print(sess.run(loss, feed_dict={xs:x_data, ys:y_data}))
所謂超引數就是指機器學習模型裡的框架引數。與權重引數不同的是,它是需要手動設定、不斷試錯的。
學習率是一個最常設定的超引數。學習率設定得越大,訓練時間越短,速度越快;而學習率設定得越小,訓練準確度越高。那麼,如何確定一個比較好的學習率呢?只能通過實驗的方法。例如,先設定0.01,觀察損失值的變化,然後嘗試0.001、0.0001,最終確定一個比較合適的學習率。
我們也可以設定可變的學習率。那麼,怎樣才算是準確率不再提高,應該停止訓練了呢?例如,在訓練過程中記錄最佳的準確率,在連續n輪沒有達到最佳的準確率時,便可以認為準確率不再提高,就可以停止訓練,稱為“early stopping”,這個策略叫做“no-improvement-in-n”規則(例如,我們設定連續10輪準確率不再變動,就認為不再提高)此時,讓學習率減半;下次滿足是,再讓學習率減半。這樣,在逐漸解決最優解時,我們的學習率越來越小,準確率就越來越高。
mini-batch大小是另一個最常設定的超引數。每批大小決定了權重的更新規則。例如,大小為32時,就是把32個樣本的梯度全部計算完,然後求平均值,去更新權重。批次越小訓練的速度就慢。那麼,如何選擇批次大小呢》也需要結合機器的硬體效能以及資料集的大小來設定。
正則項係數是另一個常用的超引數。但是,設定沒有太多可遵循的規則,一般憑經驗。一般來說,如果在較複雜的網路發現出現了明顯的過擬合(在訓練資料準確率很高但測試資料準確率反而下降),可以考慮增加此項。初學者可以一開始設定為0,然後確定好一個比較好的學習率後,再給lamda一個值,隨後根據準確率再進行精細調整。