1. 程式人生 > >儲存和恢復神經網路

儲存和恢復神經網路

轉自莫煩大神,轉載原因是想把所有相關內容收集到自己的部落格中,方便系統的學習。

兩種儲存方法,1是儲存整個神經網路;2是隻儲存神經網路的所有引數。

一、儲存神經網路

1儲存整個神經網路。

 torch.save(net1,"net1.pkl")

net1為我想要儲存的網路,net1.pkl為檔名,儲存的格式只能是.pkl

2,儲存神經網路引數

torch.save(net1.state_dict(),"net1_parmaer.pkl") 

二、恢復神經網路

1恢復完整神經網路(直接load())

net2=torch.load("net1.pkl
")

2.從引數中恢復神經網路

需先構建與所要恢復的神經網路相同結構,再load引數。

3,完整程式如下

import torch
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(
1, 10), # 一層神經層 torch.nn.ReLU(), # 加激勵函式,relu相當於類 torch.nn.Linear(10, 1), ) optimizer=torch.optim.SGD(net1.parameters(),lr=0.5) loss_func=torch.nn.MSELoss() for t in range(100): prediction=net1(x) loss=loss_func(prediction,y) optimizer.zero_grad() loss.backward() optimizer.step()
#畫圖 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('Net1') plt.scatter(x.data.numpy(), y.data.numpy()) #實際資料 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) #迴歸曲線 torch.save(net1,"net1.pkl") #儲存整個神經網路 torch.save(net1.state_dict(),"net1_parmaer.pkl") #儲存神經網路中的所有引數 def restore_net(): net2=torch.load("net1.pkl") prediction2=net2(x) plt.subplot(132) plt.title('Net2') plt.scatter(x.data.numpy(), y.data.numpy()) #實際資料 plt.plot(x.data.numpy(), prediction2.data.numpy(), 'r-', lw=5) #迴歸曲線 def restore_paramers(): net3=torch.nn.Sequential( torch.nn.Linear(1, 10), # 一層神經層 torch.nn.ReLU(), # 加激勵函式,relu相當於類 torch.nn.Linear(10, 1), ) net3.load_state_dict(torch.load("net1_parmaer.pkl")) #先構建網路在,再載入引數 prediction3 = net3(x) plt.subplot(133) plt.title('Net3') plt.scatter(x.data.numpy(), y.data.numpy()) # 實際資料 plt.plot(x.data.numpy(), prediction3.data.numpy(), 'r-', lw=5) # 迴歸曲線 plt.show() save() restore_net()

執行結果: