1. 程式人生 > >[PyTorch 學習筆記] 7.1 模型儲存與載入

[PyTorch 學習筆記] 7.1 模型儲存與載入

> 本章程式碼: > > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py) > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_load.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_load.py) > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/checkpoint_resume.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/checkpoint_resume.py) > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/save_checkpoint.py](https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/save_checkpoint.py) 這篇文章主要介紹了序列化與反序列化,以及 PyTorch 中的模型保存於載入的兩種方式,模型的斷點續訓練。 # 序列化與反序列化 模型在記憶體中是以物件的邏輯結構儲存的,但是在硬碟中是以二進位制流的方式儲存的。 - 序列化是指將記憶體中的資料以二進位制序列的方式儲存到硬碟中。PyTorch 的模型儲存就是序列化。 - 反序列化是指將硬碟中的二進位制序列載入到記憶體中,得到模型的物件。PyTorch 的模型載入就是反序列化。 # PyTorch 中的模型儲存與載入 ## torch.save ``` torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False) ``` 主要引數: - obj:儲存的物件,可以是模型。也可以是 dict。因為一般在儲存模型時,不僅要儲存模型,還需要儲存優化器、此時對應的 epoch 等引數。這時就可以用 dict 包裝起來。 - f:輸出路徑 其中模型儲存還有兩種方式: ### 儲存整個 Module 這種方法比較耗時,儲存的檔案大 ``` torch.savev(net, path) ``` ### 只儲存模型的引數 推薦這種方法,執行比較快,儲存的檔案比較小 ``` state_sict = net.state_dict() torch.savev(state_sict, path) ``` 下面是儲存 LeNet 的例子。在網路初始化中,把權值都設定為 2020,然後儲存模型。 ``` import torch import numpy as np import torch.nn as nn from common_tools import set_seed class LeNet2(nn.Module): def __init__(self, classes): super(LeNet2, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2) ) self.classifier = nn.Sequential( nn.Linear(16*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, classes) ) def forward(self, x): x = self.features(x) x = x.view(x.size()[0], -1) x = self.classifier(x) return x def initialize(self): for p in self.parameters(): p.data.fill_(2020) net = LeNet2(classes=2019) # "訓練" print("訓練前: ", net.features[0].weight[0, ...]) net.initialize() print("訓練後: ", net.features[0].weight[0, ...]) path_model = "./model.pkl" path_state_dict = "./model_state_dict.pkl" # 儲存整個模型 torch.save(net, path_model) # 儲存模型引數 net_state_dict = net.state_dict() torch.save(net_state_dict, path_state_dict) ``` 執行完之後,資料夾中生成了``model.pkl`和`model_state_dict.pkl`,分別儲存了整個網路和網路的引數 ## torch.load ``` torch.load(f, map_location=None, pickle_module, **pickle_load_args) ``` 主要引數: - f:檔案路徑 - map_location:指定存在 CPU 或者 GPU。 載入模型也有兩種方式 ### 載入整個 Module 如果儲存的時候,儲存的是整個模型,那麼載入時就載入整個模型。這種方法不需要事先建立一個模型物件,也不用知道模型的結構,程式碼如下: ``` path_model = "./model.pkl" net_load = torch.load(path_model) print(net_load) ``` 輸出如下: ``` LeNet2( (features): Sequential( (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (1): ReLU() (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (4): ReLU() (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (classifier): Sequential( (0): Linear(in_features=400, out_features=120, bias=True) (1): ReLU() (2): Linear(in_features=120, out_features=84, bias=True) (3): ReLU() (4): Linear(in_features=84, out_features=2019, bias=True) ) ) ``` ### 只加載模型的引數 如果儲存的時候,儲存的是模型的引數,那麼載入時就引數。這種方法需要事先建立一個模型物件,再使用模型的`load_state_dict()`方法把引數載入到模型中,程式碼如下: ``` path_state_dict = "./model_state_dict.pkl" state_dict_load = torch.load(path_state_dict) net_new = LeNet2(classes=2019) print("載入前: ", net_new.features[0].weight[0, ...]) net_new.load_state_dict(state_dict_load) print("載入後: ", net_new.features[0].weight[0, ...]) ``` # 模型的斷點續訓練 在訓練過程中,可能由於某種意外原因如斷點等導致訓練終止,這時需要重新開始訓練。斷點續練是在訓練過程中每隔一定次數的 epoch 就儲存**模型的引數和優化器的引數**,這樣如果意外終止訓練了,下次就可以重新載入最新的**模型引數和優化器的引數**,在這個基礎上繼續訓練。 下面的程式碼中,每隔 5 個 epoch 就儲存一次,儲存的是一個 dict,包括模型引數、優化器的引數、epoch。然後在 epoch 大於 5 時,就`break`模擬訓練意外終止。關鍵程式碼如下: ``` if (epoch+1) % checkpoint_interval == 0: checkpoint = {"model_state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch} path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch) torch.save(checkpoint, path_checkpoint) ``` 在 epoch 大於 5 時,就`break`模擬訓練意外終止 ``` if epoch > 5: print("訓練意外中斷...") break ``` 斷點續訓練的恢復程式碼如下: ``` path_checkpoint = "./checkpoint_4_epoch.pkl" checkpoint = torch.load(path_checkpoint) net.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] scheduler.last_epoch = start_epoch ``` 需要注意的是,還要設定`scheduler.last_epoch`引數為儲存的 epoch。模型訓練的起始 epoch 也要修改為儲存的 epoch。 **參考資料** - [深度之眼 PyTorch 框架班](https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6)
如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章