1. 程式人生 > >Pytorch常用函式

Pytorch常用函式

一、模型的儲存與載入
實現訓練過程中模型的儲存,以及在預訓練的基礎上繼續訓練模型
①儲存和載入整個模型

# 儲存和載入整個模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

②只儲存模型中的引數

# 僅儲存和載入模型引數(推薦使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

二、關於計算準確率的常用函式
1.torch.mean()函式
函式原型:torch.mean(input, dim, out=None)
功能:返回輸入張量給定維度dim上每行的均值,輸出形狀與輸入相同,除了給定維度上為1.。
引數:
input (Tensor) – 輸入張量
dim (int) – the dimension to reduce
out (Tensor, optional) – 結果張量

>>> a = torch.randn(4, 4)
>>> a

-1.2738 -0.3058  0.1230 -1.9615
 0.8771 -0.5430 -0.9233  0.9879
 1.4107  0.0317 -0.6823  0.2255
-1.3854  0.4953 -0.2160  0.2435
[torch.FloatTensor of size 4x4]

>>> torch.mean(a, 1)

-0.8545
 0.0997
 0.2464
-0.2157
[torch.FloatTensor of size 4x1]

2.torch.eq()函式
比較兩向量是否,兩個向量的維度必須一致,如果相等,對應維度上的數為1,若果不相等則對應位置上的元素為0.

import torch
import torch.nn as nn

output = torch.FloatTensor([[0.91,0,0.9,0,0],
							[1.0,0,0.98,0,0.9],
							[0,0,0,1.0,0],
							[1.0,0,1.0,0,0],
							[0.9,1.0,0.98,0.89,0.60]])
_, pred = output.topk(1,1,True,True)
# print(pred)
pred = pred.t()#轉置
target = torch.FloatTensor([[0, 0, 2, 2, 1]])
# print(target.view(1, -1))
exp_t = target.view(1, -1).expand_as(pred)#b.expand_as(a)就是將b進行擴充,擴充到a的維度,需要說明的是a的低維度需要比b大,例如b的shape是31,如果a的shape是32不會出錯, correct = pred.eq(exp_t.long())#注意這裡的eq函式的引數必須為LongTensor型別,這裡利用了long()的方法將其轉化為LongTensor型別。 correct_k = correct[:1].view(-1).float().sum(0) top1_acc = correct_k.mul_(100.0 / 5) print("top1_acc:",top1_acc)

關於Tensor的型別梳理見該文章
關於Tensor的維度操作梳理見該文章