1. 程式人生 > >pytorch-神經網路擬合曲線

pytorch-神經網路擬合曲線

程式碼已經調通,跑出來的效果如下:

# coding=gbk
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torch.nn.functional as F

'''
    Pytorch是一個擁有強力GPU加速的張量和動態構建網路的庫,其主要構建是張量,所以可以把PyTorch當做Numpy
    來用,Pytorch的很多操作好比Numpy都是類似的,但是其能夠在GPU上執行,所以有著比Numpy快很多倍的速度。

    訓練完了,發現隱層越大,擬合的速度越是快,擬合的效果越是好
'''

def train():
    print('------      構建資料集      ------')
    # torch.linspace是為了生成連續間斷的資料,第一個引數表示起點,第二個引數表示終點,第三個引數表示將這個區間分成平均幾份,即生成幾個資料
    x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
    #torch.rand返回的是[0,1]之間的均勻分佈   這裡是使用一個計算式子來構造出一個關聯結果,當然後期要學的也就是這個式子
    y = x.pow(2) + 0.2 * torch.rand(x.size())
    # Variable是將tensor封裝了下,用於自動求導使用
    x, y = Variable(x), Variable(y)
    #繪圖展示
    plt.scatter(x.data.numpy(), y.data.numpy())
    #plt.show()

    print('------      搭建網路      ------')
    #使用固定的方式繼承並重寫 init和forword兩個類
    class Net(torch.nn.Module):
        def __init__(self,n_feature,n_hidden,n_output):
            #初始網路的內部結構
            super(Net,self).__init__()
            self.hidden=torch.nn.Linear(n_feature,n_hidden)
            self.predict=torch.nn.Linear(n_hidden,n_output)
        def forward(self, x):
            #一次正向行走過程
            x=F.relu(self.hidden(x))
            x=self.predict(x)
            return x
    net=Net(n_feature=1,n_hidden=1000,n_output=1)
    print('網路結構為:',net)

    print('------      啟動訓練      ------')
    loss_func=F.mse_loss
    optimizer=torch.optim.SGD(net.parameters(),lr=0.001)

    #使用資料 進行正向訓練,並對Variable變數進行反向梯度傳播  啟動100次訓練
    for t in range(10000):
        #使用全量資料 進行正向行走
        prediction=net(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()  #清除上一梯度
        loss.backward() #反向傳播計算梯度
        optimizer.step()  #應用梯度

        #間隔一段,對訓練過程進行視覺化展示
        if t%5==0:
            plt.cla()
            plt.scatter(x.data.numpy(),y.data.numpy()) #繪製真是曲線
            plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
            plt.text(0.5,0,'Loss='+str(loss.data[0]),fontdict={'size':20,'color':'red'})
            plt.pause(0.1)
    plt.ioff()
    plt.show()
    print('------      預測和視覺化      ------')

if __name__=='__main__':
    train()