1. 程式人生 > >【python學習筆記】13:用梯度下降法求解最優值問題

【python學習筆記】13:用梯度下降法求解最優值問題

梯度是函式在某點沿每個座標的偏導數構成的向量,它反映了函式沿著哪個方向增加得最快。因此要求解一個二元函式的極小值,只要沿著梯度的反方向走,直到函式值的變化滿足精度即可。

這裡打表儲存了途徑的每個點,最後在圖上繪製出來以反映路徑。

*梯度下降的具體實現

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def Fun(x,y):#原函式
    return x-y+2*x*x+2*x*y+y*y

def PxFun(x,y):#偏x導
    return 1+4*x+2*y

def PyFun(x,y):#偏y導
    return -1+2*x+2*y

#初始化
fig=plt.figure()#figure物件
ax=Axes3D(fig)#Axes3D物件
X,Y=np.mgrid[-2:2:40j,-2:2:40j]#取樣並作滿射聯合
Z=Fun(X,Y)#取樣點Z座標打表
ax.plot_surface(X,Y,Z,rstride=1,cstride=1,cmap="rainbow")
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

#梯度下降
step=0.0008#下降係數
x=0
y=0#初始選取一個點
tag_x=[x]
tag_y=[y]
tag_z=[Fun(x,y)]#三個座標分別打入表中,該表用於繪製點
new_x=x
new_y=y
Over=False
while Over==False:
    new_x-=step*PxFun(x,y)
    new_y-=step*PyFun(x,y)#分別作梯度下降
    if Fun(x,y)-Fun(new_x,new_y)<7e-9:#精度
        Over=True
    x=new_x
    y=new_y#更新舊點
    tag_x.append(x)
    tag_y.append(y)
    tag_z.append(Fun(x,y))#新點三個座標打入表中

#繪製點/輸出座標
ax.plot(tag_x,tag_y,tag_z,'r.')
plt.title('(x,y)~('+str(x)+","+str(y)+')')
plt.show()


執行結果:

換個角度看:

驗證結果是否正確: