1. 程式人生 > >tensorflow-梯度下降,有這一篇就夠了(深度好文)

tensorflow-梯度下降,有這一篇就夠了(深度好文)

前言

最近機器學習越來越火了,前段時間斯丹福大學副教授吳恩達都親自錄製了關於Deep Learning Specialization的教程,在國內掀起了巨大的學習熱潮。本著不被時代拋棄的念頭,自己也開始研究有關機器學習的知識。都說機器學習的學習難度非常大,但不親自嘗試一下又怎麼會知道其中的奧妙與樂趣呢?只有不斷的嘗試才能找到最適合自己的道路。

請容忍我上述的自我煽情,下面進入主題。這篇文章主要對機器學習中所遇到的GradientDescent(梯度下降)進行全面分析,相信你看了這篇文章之後,對GradientDescent將徹底弄明白其中的原理。

梯度下降的概念

梯度下降法是一個一階最優化演算法,通常也稱為最速下降法。要使用梯度下降法找到一個函式的區域性極小值,必須向函式上當前點對於梯度(或者是近似梯度)的反方向的規定步長距離點進行迭代搜尋。所以梯度下降法可以幫助我們求解某個函式的極小值或者最小值。對於n維問題就最優解,梯度下降法是最常用的方法之一。下面通過梯度下降法的前生今世

來進行詳細推導說明。

梯度下降法的前世

首先從簡單的開始,看下面的一維函式:

f(x) = x^3 + 2 * x - 3

在數學中如果我們要求f(x) = 0處的解,我們可以通過如下誤差等式來求得:

error = (f(x) - 0)^2

error趨近於最小值時,也就是f(x) = 0x的解,我們也可以通過圖來觀察:

通過這函式圖,我們可以非常直觀的發現,要想求得該函式的最小值,只要將x指定為函式圖的最低谷。這在高中我們就已經掌握了該函式的最小值解法。我們可以通過對該函式進行求導(即斜率):

derivative(x) = 6 * x^5 + 16 * x^3 - 18 * x^2 + 8 * x - 12

如果要得到最小值,只需令derivative(x) = 0,即x = 1。同時我們結合圖與導函式可以知道:

  • x < 1時,derivative < 0,斜率為負的;

  • x > 1時,derivative > 0,斜率為正的;

  • x 無限接近 1時,derivative也就無限=0,斜率為零。

通過上面的結論,我們可以使用如下表達式來代替x在函式中的移動

x = x - reate * derivative

當斜率為負的時候,x增大,當斜率為正的時候,x減小;因此x總是會向著低谷移動,使得error最小,從而求得 f(x) = 0處的解。其中的rate代表x逆著導數方向移動的距離,rate越大,x

每次就移動的越多。反之移動的越少。

這是針對簡單的函式,我們可以非常直觀的求得它的導函式。為了應對複雜的函式,我們可以通過使用求導函式的定義來表達導函式:若函式f(x)在點x0處可導,那麼有如下定義:

上面是都是公式推導,下面通過程式碼來實現,下面的程式碼都是使用python進行實現。

>>> def f(x):
...     return x**3 + 2 * x - 3
...
>>> def error(x):
...     return (f(x) - 0)**2
...
>>> def gradient_descent(x):
...     delta = 0.00000001
...     derivative = (error(x + delta) - error(x)) / delta
...     rate = 0.01
...     return x - rate * derivative
...
>>> x = 0.8
>>> for i in range(50):
...     x = gradient_descent(x)
...     print('x = {:6f}, f(x) = {:6f}'.format(x, f(x)))
...

執行上面程式,我們就能得到如下結果:

x = 0.869619, f(x) = -0.603123
x = 0.921110, f(x) = -0.376268
x = 0.955316, f(x) = -0.217521
x = 0.975927, f(x) = -0.118638
x = 0.987453, f(x) = -0.062266
x = 0.993586, f(x) = -0.031946
x = 0.996756, f(x) = -0.016187
x = 0.998369, f(x) = -0.008149
x = 0.999182, f(x) = -0.004088
x = 0.999590, f(x) = -0.002048
x = 0.999795, f(x) = -0.001025
x = 0.999897, f(x) = -0.000513
x = 0.999949, f(x) = -0.000256
x = 0.999974, f(x) = -0.000128
x = 0.999987, f(x) = -0.000064
x = 0.999994, f(x) = -0.000032
x = 0.999997, f(x) = -0.000016
x = 0.999998, f(x) = -0.000008
x = 0.999999, f(x) = -0.000004
x = 1.000000, f(x) = -0.000002
x = 1.000000, f(x) = -0.000001
x = 1.000000, f(x) = -0.000001
x = 1.000000, f(x) = -0.000000
x = 1.000000, f(x) = -0.000000
x = 1.000000, f(x) = -0.000000

通過上面的結果,也驗證了我們最初的結論。x = 1時,f(x) = 0
所以通過該方法,只要步數足夠多,就能得到非常精確的值。

梯度下降法的今生

上面是對一維函式進行求解,那麼對於多維函式又要如何求呢?我們接著看下面的函式,你會發現對於多維函式也是那麼的簡單。

f(x) = x[0] + 2 * x[1] + 4

同樣的如果我們要求f(x) = 0處,x[0]x[1]的值,也可以通過求error函式的最小值來間接求f(x)的解。跟一維函式唯一不同的是,要分別對x[0]x[1]進行求導。在數學上叫做偏導數

  • 保持x[1]不變,對x[0]進行求導,即f(x)x[0]的偏導數

  • 保持x[0]不變,對x[1]進行求導,即f(x)x[1]的偏導數

有了上面的理解基礎,我們定義的gradient_descent如下:

>>> def gradient_descent(x):
...     delta = 0.00000001
...     derivative_x0 = (error([x[0] + delta, x[1]]) - error([x[0], x[1]])) / delta
...     derivative_x1 = (error([x[0], x[1] + delta]) - error([x[0], x[1]])) / delta
...     rate = 0.01
...     x[0] = x[0] - rate * derivative_x0
...     x[1] = x[1] - rate * derivative_x1
...     return [x[0], x[1]]
...

rate的作用不變,唯一的區別就是分別獲取最新的x[0]x[1]。下面是整個程式碼:

>>> def f(x):
...     return x[0] + 2 * x[1] + 4
...
>>> def error(x):
...     return (f(x) - 0)**2
...
>>> def gradient_descent(x):
...     delta = 0.00000001
...     derivative_x0 = (error([x[0] + delta, x[1]]) - error([x[0], x[1]])) / delta
...     derivative_x1 = (error([x[0], x[1] + delta]) - error([x[0], x[1]])) / delta
...     rate = 0.02
...     x[0] = x[0] - rate * derivative_x0
...     x[1] = x[1] - rate * derivative_x1
...     return [x[0], x[1]]
...
>>> x = [-0.5, -1.0]
>>> for i in range(100):
...     x = gradient_descent(x)
...     print('x = {:6f},{:6f}, f(x) = {:6f}'.format(x[0],x[1],f(x)))
...

輸出結果為:

x = -0.560000,-1.120000, f(x) = 1.200000
x = -0.608000,-1.216000, f(x) = 0.960000
x = -0.646400,-1.292800, f(x) = 0.768000
x = -0.677120,-1.354240, f(x) = 0.614400
x = -0.701696,-1.403392, f(x) = 0.491520
x = -0.721357,-1.442714, f(x) = 0.393216
x = -0.737085,-1.474171, f(x) = 0.314573
x = -0.749668,-1.499337, f(x) = 0.251658
x = -0.759735,-1.519469, f(x) = 0.201327
x = -0.767788,-1.535575, f(x) = 0.161061
x = -0.774230,-1.548460, f(x) = 0.128849
x = -0.779384,-1.558768, f(x) = 0.103079
x = -0.783507,-1.567015, f(x) = 0.082463
x = -0.786806,-1.573612, f(x) = 0.065971
x = -0.789445,-1.578889, f(x) = 0.052777
x = -0.791556,-1.583112, f(x) = 0.042221
x = -0.793245,-1.586489, f(x) = 0.033777
x = -0.794596,-1.589191, f(x) = 0.027022
x = -0.795677,-1.591353, f(x) = 0.021617
x = -0.796541,-1.593082, f(x) = 0.017294
x = -0.797233,-1.594466, f(x) = 0.013835
x = -0.797786,-1.595573, f(x) = 0.011068
x = -0.798229,-1.596458, f(x) = 0.008854
x = -0.798583,-1.597167, f(x) = 0.007084
x = -0.798867,-1.597733, f(x) = 0.005667
x = -0.799093,-1.598187, f(x) = 0.004533
x = -0.799275,-1.598549, f(x) = 0.003627
x = -0.799420,-1.598839, f(x) = 0.002901
x = -0.799536,-1.599072, f(x) = 0.002321
x = -0.799629,-1.599257, f(x) = 0.001857
x = -0.799703,-1.599406, f(x) = 0.001486
x = -0.799762,-1.599525, f(x) = 0.001188
x = -0.799810,-1.599620, f(x) = 0.000951
x = -0.799848,-1.599696, f(x) = 0.000761
x = -0.799878,-1.599757, f(x) = 0.000608
x = -0.799903,-1.599805, f(x) = 0.000487
x = -0.799922,-1.599844, f(x) = 0.000389
x = -0.799938,-1.599875, f(x) = 0.000312
x = -0.799950,-1.599900, f(x) = 0.000249
x = -0.799960,-1.599920, f(x) = 0.000199
x = -0.799968,-1.599936, f(x) = 0.000159
x = -0.799974,-1.599949, f(x) = 0.000128
x = -0.799980,-1.599959, f(x) = 0.000102
x = -0.799984,-1.599967, f(x) = 0.000082
x = -0.799987,-1.599974, f(x) = 0.000065
x = -0.799990,-1.599979, f(x) = 0.000052
x = -0.799992,-1.599983, f(x) = 0.000042
x = -0.799993,-1.599987, f(x) = 0.000033
x = -0.799995,-1.599989, f(x) = 0.000027
x = -0.799996,-1.599991, f(x) = 0.000021
x = -0.799997,-1.599993, f(x) = 0.000017
x = -0.799997,-1.599995, f(x) = 0.000014
x = -0.799998,-1.599996, f(x) = 0.000011
x = -0.799998,-1.599997, f(x) = 0.000009
x = -0.799999,-1.599997, f(x) = 0.000007
x = -0.799999,-1.599998, f(x) = 0.000006
x = -0.799999,-1.599998, f(x) = 0.000004
x = -0.799999,-1.599999, f(x) = 0.000004
x = -0.799999,-1.599999, f(x) = 0.000003
x = -0.800000,-1.599999, f(x) = 0.000002
x = -0.800000,-1.599999, f(x) = 0.000002
x = -0.800000,-1.599999, f(x) = 0.000001
x = -0.800000,-1.600000, f(x) = 0.000001
x = -0.800000,-1.600000, f(x) = 0.000001
x = -0.800000,-1.600000, f(x) = 0.000001
x = -0.800000,-1.600000, f(x) = 0.000001
x = -0.800000,-1.600000, f(x) = 0.000000

細心的你可能會發現,f(x) = 0不止這一個解還可以是x = -2, -1。這是因為梯度下降法只是對當前所處的凹谷進行梯度下降求解,對於error函式並不代表只有一個f(x) = 0的凹谷。所以梯度下降法只能求得區域性解,但不一定能求得全部的解。當然如果對於非常複雜的函式,能夠求得區域性解也是非常不錯的。

tensorflow中的運用

通過上面的示例,相信對梯度下降也有了一個基本的認識。現在我們回到最開始的地方,在tensorflow中使用gradientDescent

import tensorflow as tf
 
# Model parameters
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)
# Model input and output
x = tf.placeholder(tf.float32)
linear_model = W*x + b
y = tf.placeholder(tf.float32)
 
# loss
loss = tf.reduce_sum(tf.square(linear_model - y)) # sum of the squares
# optimizer
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
 
# training data
x_train = [1, 2, 3, 4]
y_train = [0, -1, -2, -3]
# training loop
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(1000):
  sess.run(train, {x: x_train, y: y_train})
 
# evaluate training accuracy
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))

上面的是tensorflow的官網示例,上面程式碼定義了函式linear_model = W * x + b,其中的error函式為linear_model - y。目的是對一組x_trainy_train進行簡單的訓練求解Wb。為了求得這一組資料的最優解,將每一組的error相加從而得到loss,最後再對loss進行梯度下降求解最優值。

optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

在這裡rate0.01,因為這個示例也是多維函式,所以也要用到偏導數來進行逐步向最優解靠近。

for i in range(1000):
  sess.run(train, {x: x_train, y: y_train})
   

最後使用梯度下降進行迴圈推導,下面給出一些推導過程中的相關結果

W: [-0.21999997] b: [-0.456] loss: 4.01814
W: [-0.39679998] b: [-0.49552] loss: 1.81987
W: [-0.45961601] b: [-0.4965184] loss: 1.54482
W: [-0.48454273] b: [-0.48487374] loss: 1.48251
W: [-0.49684232] b: [-0.46917531] loss: 1.4444
W: [-0.50490189] b: [-0.45227283] loss: 1.4097
W: [-0.5115062] b: [-0.43511063] loss: 1.3761
....
....
....
W: [-0.99999678] 
            
           

相關推薦

tensorflow-梯度下降(深度)

前言最近機器學習越來越火了,前段時間斯丹福大學副教授吳恩達都親自錄製了關於Deep Learning Specialization的教程,在國內掀起了巨大的學習熱潮。本著不被時代拋棄的念頭,自己也開始研究有關機器學習的知識。都說機器學習的學習難度非常大,但不親自嘗試一下又怎麼

Android:RecyclerView 的使用

謹以文章記錄學習歷程,如有錯誤還請指明。 RecyclerView 簡介 首先,可以理解 RecyclerView 是 ListView 的升級版,更加靈活,同時由於封裝了 ListView 的部分實現,導致其使用更簡單,結構更清晰。 從名字

Azure IOT 設備固件更新技巧

trigger 物聯網平臺 搭建 href ice 有效 面板 調用 創建 嫌長不看版 今天為大家準備的硬菜是:在 Azure IoT 中心創建 Node.js 控制臺應用,進行端到端模擬固件更新,為基於 Intel Edison 的設備安裝新版固件的流程。通過創建模擬設備

想做好PPT折線圖

12月 image 菊花 -c 強調 spa any border 線圖 配圖主題無關今天鄭少跟大家聊聊折線圖的使用方法,或者你有疑問,折線圖很簡單,插入修改數據不就好了嗎?如果你要是這樣想的,恭喜你,有可能你會做出下面這樣的效果。如果你要是稍微懂一點折線圖的使用方法,你就

Linux 問題故障定位

1. 背景 有時候會遇到一些疑難雜症,並且監控外掛並不能一眼立馬發現問題的根源。這時候就需要登入伺服器進一步深入分析問題的根源。那麼分析問題需要有一定的技術經驗積累,並且有些問題涉及到的領域非常廣,才能定位到問題。所以,分析問題和踩坑是非常鍛鍊一個人的成長和提升自我能力。如果我們有一套好的分析工具,那將是事

C語言從入門到精通

影響 內容 當前 位置 replace 雙精度 下標 寄存器變量 一個 No.1 計算機與程序設計語言的關系 計算機系統由硬件系統和軟件系統構成,硬件相當於人類的肉體,而軟件相當於人類的靈魂,如果脫離了靈魂,人類就是一具行屍走肉 No.2 C語言的特點 代碼簡潔,靈活性高

【MYSQL學習筆記02】MySQL的高階應用之Explain(完美詳細版

版權宣告:本文為博主原創文章,未經博主允許不得轉載。 https://blog.csdn.net/wx1528159409 最近學習MySQL的高階應用Explain,寫一篇學習心得與總結,目錄腦圖如下: 一、Explain基本概念 1. Explain定義 · 我們知道M

抖音內容運營全解剖 !

抖音的火爆已經不用多說,作為短視訊的頭部APP,抖音已經從微信手中奪走不少使用者時間,成為新的“時間黑洞”。 比如:“中毒了,我每天晚上要刷2個小時”,“下一站,逃離微信,上抖音”… 一個企業運營抖音的目的是什麼? 答案顯而易見,無非就是做品牌營銷、擴大品牌影響力。 在短視訊領域積累

百萬併發下的Nginx優化

本文作者主要分享在 Nginx 效能方面的實踐經驗,希望能給大家帶來一些系統化思考,幫助大家更有效地去做 Nginx。 優化方法論 我重點分享如下兩個問題: 保持併發連線數,怎麼樣做到記憶體有效使用。 在高併發的同時保持高吞吐量的重要要點。 實現層面主要是三方面優化,主要聚焦

理解Sharding jdbc原理

相比於Spring基於AbstractRoutingDataSource實現的分庫分表功能,Sharding jdbc在單庫單表擴充套件到多庫多表時,相容性方面表現的更好一點。例如,spring實現的分庫分表sql寫法如下: select id, name, price,

產品設計教程:如何理解 px,dp,dpi, pt

先聊聊熟悉的幾個單位 圍繞著各種螢幕做設計和開發的人會碰到下面幾個單位:in, pt, px, dpi,dip/dp, sp 下面先簡單回顧下前四個單位: "in" inches的縮寫,英寸。就是螢幕的物理長度單位。一英寸等於2.54cm。比如Android手機

中後臺產品的表格設計(原型規範下載)

中後臺產品的表格設計,看這一篇就夠了(原型規範下載) 2018年4月16日luodonggan 中後臺產品的表格設計,看這一篇就夠了(原型規範下載) 經過了將近一年的後臺產品經歷,踩了很多坑,試了很多錯,也學習到了很多東西,目前也形成了自己的一套規範。本文將其中的部分收穫彙總成文,

Linux 常用指令 —— 摘自《Linux Probe》

touch:用於建立空白檔案或設定檔案的時間,ps:黑客可以用touch指令來修改檔案的最後修改時間,以隱藏自己的修改行為。 mkdir:用於建立空白的目錄,如mkdir path,可以結合引數-p來遞迴建立檔案目錄,如mkdir -p a/b/c/d/e cp:用於複製檔案或目錄,如cp 1.txt p

樹狀陣列(Binary Indexed Tree)

定義 根據維基百科的定義: A Fenwick tree or binary indexed tree is a data structure that can efficiently update elements and calculate pr

Cookie介紹及在Android中的使用總結超詳細

Cookie介紹 cookie的起源 早期Web開發面臨的最大問題之一是如何管理狀態。簡言之,伺服器端沒有辦法知道兩個請求是否來自於同一個瀏覽器。那時的辦法是在請求的頁面中插入一個token,並且在下一次請求中將這個token返回(至伺服器)。這就需要在form中插入一個包含toke

關於Kaggle入門

這次醞釀了很久想給大家講一些關於Kaggle那點兒事,幫助對資料科學(Data Science)有興趣的同學們更好的瞭解這個專案,最好能親身參與進來,體會一下學校所學的東西和想要解決一個實際的問題所需要的能力的差距。雖然不是Data Science出身,但本著嚴謹的科研態

並查集(Union-Find Algorithm)

動態連線(Dynamic connectivity)的問題 所謂的動態連線問題是指在一組可能相互連線也可能相互沒有連線的物件中,判斷給定的兩個物件是否聯通的一類問題。這類問題可以有如下抽象: 有一組構成不相交集合的物件 union: 聯通兩個物件

Android 必須知道2018年流行的框架庫及開發語言

導語2017 已經悄悄的走了,2018 也已經匆匆的來了,我們在總結過去的同時,也要展望一下未來,來規劃一下今年要學哪些新技術。這幾年優秀Android的開源庫不斷推出,新技術層出不窮,需要我們不斷去了解和掌握,在提高自身開發水平的同時,我們需要付出更多學習精力和時間。俗話說

Android 必須知道2018年流行的框架庫及開發語言

本文更新時間:2018年07月12日15:50:40導語    2017 已經悄悄的走了,2018 也已經匆匆的來了,我們在總結過去的同時,也要展望一下未來,來規劃一下今年要學哪些新技術。這幾年優秀Android的開源庫不斷推出,新技術層出不窮,需要我們不斷去了解和掌握,在提

關於使用format()方法格式化字串

從Python 2.6開始,又出現了另外一種格式化字串的方法——format()方法。format()方法是字串眾多方法中的一個,呼叫這個方法時要使用點操作符(.),該方法返回一個格式化好的字串。其呼叫格式如下:   s.format(……)   其中,s是一個待格式化的字串,裡面