1. 程式人生 > >pytorch筆記02)模型的儲存和載入

pytorch筆記02)模型的儲存和載入

儲存和載入整個模型

torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

僅儲存和載入模型引數(推薦使用,需要提前手動構建模型)

torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

但是要注意幾個細節:
1.若使用nn.DataParallel在一臺電腦上使用了多個GPU,load模型的時候也必須先DataParallel,這和keras類似。

2.load提供了很多過載的功能,其可以把在GPU上訓練的權重載入到CPU上跑。內容參考於:https://www.ptorch.com/news/74.html

torch.load('tensors.pt')
# 把所有的張量載入到CPU中
torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# 把所有的張量載入到GPU 1中
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# 把張量從GPU 1 移動到 GPU 0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

在cpu上載入預先訓練好的GPU模型,有一種強制所有GPU張量在CPU中的方式:

torch.load('my_file.pt', map_location=lambda storage, loc: storage)

上述程式碼只有在模型在一個GPU上訓練時才起作用。如果我在多個GPU上訓練我的模型,儲存它,然後嘗試在CPU上載入,我得到這個錯誤:KeyError: ‘unexpected key “module.conv1.weight” in state_dict

’ 如何解決?
您可能已經使用模型儲存了模型nn.DataParallel,該模型將模型儲存在該模型中module,而現在您正試圖載入模型DataParallel。您可以nn.DataParallel在網路中暫時新增一個載入目的,也可以載入權重檔案,建立一個沒有module字首的新的有序字典,然後載入它。

# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

筆者封裝了一個簡單的函式,可以直接載入多GPU權重到CPU上(只加載匹配的權重)

# 載入模型,解決命名和維度不匹配問題,解決多個gpu並行
def load_state_keywise(model, model_path):
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location='cpu')
    key = list(pretrained_dict.keys())[0]
    # 1. filter out unnecessary keys
    # 1.1 multi-GPU ->CPU
    if (str(key).startswith('module.')):
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if
                           k[7:] in model_dict and v.size() == model_dict[k[7:]].size()}
    else:
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if
                           k in model_dict and v.size() == model_dict[k].size()}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

有朋友問,上面為什麼去掉‘’module.‘’就可以在單GPU上跑了,看下面的一個栗子

import torch
from torch import nn
import torchvision
#使用alexnet做測試,使用單個GPU或CUP 
alexnet=torchvision.models.alexnet()
state_dict=alexnet.state_dict()
for k, v in state_dict.items():
    print(k)
print("-"*20)  #華麗的分割線
#使用多GPU
model = nn.DataParallel(alexnet)
state_dict=model.state_dict()
for k, v in state_dict.items():
    print(k)

看結果就知道了,其就多了個字首‘module.’

features.0.weight
features.0.bias
features.3.weight
features.3.bias
features.6.weight
features.6.bias
features.8.weight
features.8.bias
features.10.weight
features.10.bias
classifier.1.weight
classifier.1.bias
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias
--------------------
module.features.0.weight
module.features.0.bias
module.features.3.weight
module.features.3.bias
module.features.6.weight
module.features.6.bias
module.features.8.weight
module.features.8.bias
module.features.10.weight
module.features.10.bias
module.classifier.1.weight
module.classifier.1.bias
module.classifier.4.weight
module.classifier.4.bias
module.classifier.6.weight
module.classifier.6.bias