Pytorch 儲存和載入模型 part2
阿新 • • 發佈:2018-11-03
搭建網路:
torch.manual_seed(1) # reproducible # 假資料 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1) y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) def save(): # 建網路 net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) loss_func = torch.nn.MSELoss() # 訓練 for t in range(100): prediction = net1(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step()
兩種儲存途徑:
torch.save(net1, 'net.pkl') # 儲存整個網路
torch.save(net1.state_dict(), 'net_params.pkl') # 只儲存網路中的引數 (速度快, 佔記憶體少)
兩種提取方法:
1 提取整個網路:
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x)
2 只提取網路引數:
def restore_params(): # 新建 net3 net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) # 將儲存的引數複製到 net3 net3.load_state_dict(torch.load('net_params.pkl')) prediction = net3(x)
儲存顯示檢視
# 儲存 net1 (1. 整個網路, 2. 只有引數)
save()
# 提取整個網路
restore_net()
# 提取網路引數, 複製到新網路
restore_params(
本文參考:https://morvanzhou.github.io/tutorials/machine-learning/torch/3-04-save-reload/