1. 程式人生 > >pytorch基本

pytorch基本

pytorch主要分為以下幾個模組來訓練模型:

    tensor:tensor為基本結構,可以直接建立,從list建立以及由numpy陣列得到,torch還提供一套運算以及shape變換方式。
    Variable:自動求導機制,利用Variable包裝tensor後,便可以使用其求導的功能了,有點像個裝飾器。
    nn:nn模組是整個pytorch的核心,自己設計的Net(),繼承nn.Model後可以提取模型引數,進行前向forward()運算(自己設計),以及後向運算(自動),nn提供基本網路結構單元,例如nn.Linear(),nn.Conv2d()等,還提供基本損失函式nn.CrossEntropyLoss等。
    torch.optim:該模組提供自動求導更新引數等功能,用它封裝模型引數nn.parameter()後,loss求導後,可以用.step來更新整個引數。
    torch.utils.data.DataSet:該模組提供載入資料初始化的方式,完善好getitem和len的介面後,便可以利用DataLoader多程序載入資料。

 

參考:

https://blog.csdn.net/qq_16949707/article/details/79067474