【python學習筆記】13:用梯度下降法求解最優值問題
阿新 • • 發佈:2019-01-02
梯度是函式在某點沿每個座標的偏導數構成的向量,它反映了函式沿著哪個方向增加得最快。因此要求解一個二元函式的極小值,只要沿著梯度的反方向走,直到函式值的變化滿足精度即可。
這裡打表儲存了途徑的每個點,最後在圖上繪製出來以反映路徑。
*梯度下降的具體實現
import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import numpy as np def Fun(x,y):#原函式 return x-y+2*x*x+2*x*y+y*y def PxFun(x,y):#偏x導 return 1+4*x+2*y def PyFun(x,y):#偏y導 return -1+2*x+2*y #初始化 fig=plt.figure()#figure物件 ax=Axes3D(fig)#Axes3D物件 X,Y=np.mgrid[-2:2:40j,-2:2:40j]#取樣並作滿射聯合 Z=Fun(X,Y)#取樣點Z座標打表 ax.plot_surface(X,Y,Z,rstride=1,cstride=1,cmap="rainbow") ax.set_xlabel('x') ax.set_ylabel('y') ax.set_zlabel('z') #梯度下降 step=0.0008#下降係數 x=0 y=0#初始選取一個點 tag_x=[x] tag_y=[y] tag_z=[Fun(x,y)]#三個座標分別打入表中,該表用於繪製點 new_x=x new_y=y Over=False while Over==False: new_x-=step*PxFun(x,y) new_y-=step*PyFun(x,y)#分別作梯度下降 if Fun(x,y)-Fun(new_x,new_y)<7e-9:#精度 Over=True x=new_x y=new_y#更新舊點 tag_x.append(x) tag_y.append(y) tag_z.append(Fun(x,y))#新點三個座標打入表中 #繪製點/輸出座標 ax.plot(tag_x,tag_y,tag_z,'r.') plt.title('(x,y)~('+str(x)+","+str(y)+')') plt.show()
執行結果:
換個角度看:
驗證結果是否正確: