1. 程式人生 > >pytorch(四):神經網路的儲存與提取

pytorch(四):神經網路的儲存與提取

將神經網路訓練好之後,如何儲存它呢,儲存它之後有如何提取它呢?

如下圖所示,net1是訓練好的神經網路,有兩種方式儲存它:1.儲存整個訓練好的神經網路,2.儲存神經網路的最終引數

net2是根據第1種方式儲存的。net2是根據第2種方式儲存的

原始碼:

# 引入模組
import torch
import torch.nn.functional as f
from torch.autograd import Variable
import matplotlib.pyplot as plt


# 生成一些假資料
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # 神經網路只能接受二維資料的輸入
y = pow(x, 2) + 0.2*torch.rand(x.size())  # 後半部分製造噪音
x, y = Variable(x), Variable(y)  # 訓練神經網路時只能接受Variable形式輸入


# 定義儲存函式
def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
)
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for i in range(100):
        prediction = net1(x)  # 喂資料x給net1
        loss = loss_func(prediction, y)
        optimizer.zero_grad()  # 將上面運算過程中的grad清零
        loss.backward()  # 誤差反向傳遞
        optimizer.step()  # 將新引數作用於神經網路
    
    # 繪圖
    plt.figure(figsize=(10, 3))  # 設定影象的大小
    plt.subplot(131)
    plt.title('net1', color='red', size=20)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})
    
    # 儲存net1的兩種方式
    torch.save(net1, 'net1.pkl')  # 方式1:儲存整個神經網路
    torch.save(net1.state_dict(), 'net1_parameters.pkl')  # 方式2:儲存神經網路的引數      


# 定義提取整個神經網路的函式
def restore_net():
    net2 = torch.load('net1.pkl')  # 載入檔案net1.pkl, 將其內容賦值給net2
    prediction = net2(x)
    loss_func = torch.nn.MSELoss()
    loss = loss_func(prediction, y)
    
    # 繪製net2結果圖形
    plt.subplot(132)
    plt.title('net2', color='red', size=20)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})


# 定義提取神經網路狀態引數的函式
def restore_net_parameters():
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
)  # 構造net3的基本框架
    net3.load_state_dict(torch.load('net1_parameters.pkl'))  # 提取net1的狀態引數,將狀態引數給net3
    prediction = net3(x)
    loss_func = torch.nn.MSELoss()
    loss = loss_func(prediction, y)
    
    # 繪製net3結果圖形
    plt.subplot(133)
    plt.title('net3', color='red', size=20)
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.text(0.3, 0, 'loss=%.4f' % loss, fontdict={'color': 'red', 'size': 10})


# 呼叫函式
save()
restore_net()
restore_net_parameters()
plt.show()  # 將三個函式繪製的圖形顯示出來

注意:將plt.show()放置在最後,能顯示出三幅影象連在一起的。若在每個定義的函式的後面均加上plt.show(),三幅影象是分開顯示的,無法連成一個整體。