1. 程式人生 > >Pytorch之儲存讀取模型

Pytorch之儲存讀取模型

目錄

轉自這裡

pytorch儲存資料

pytorch儲存資料的格式為.t7檔案或者.pth檔案t7檔案是沿用torch7中讀取模型權重的方式。而pth檔案是python中儲存檔案的常用格式。而在keras中則是使用.h5檔案。

# 儲存模型示例程式碼
print('===> Saving models...')
state = {
    'state': model.state_dict(),
    'epoch': epoch                   # 將epoch一併儲存
}
if not os.path.isdir('checkpoint'):
    os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

儲存用到torch.save函式,注意該函式第一個引數可以是單個值也可以是字典,字典可以存更多你要儲存的引數(不僅僅是權重資料)。

pytorch讀取資料

pytorch讀取資料使用的方法和我們平時使用預訓練引數所用的方法是一樣的,都是使用load_state_dict這個函式。

下方的程式碼和上方的儲存程式碼可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
    try:
        checkpoint = torch.load('./checkpoint/autoencoder.t7')
        model.load_state_dict(checkpoint['state'])        # 從字典中依次讀取
        start_epoch = checkpoint['epoch']
        print('===> Load last checkpoint data')
    except FileNotFoundError:
        print('Can\'t found autoencoder.t7')
else:
    start_epoch = 0
    print('===> Start from scratch')

以上是pytorch讀取的方法彙總,但是要注意,在使用官方的預處理模型進行讀取時,一般使用的格式是pth,使用官方的模型讀取命令會檢查你模型的格式是否正確,如果不是使用官方提供模型通過下面的函式強行讀取模型(將其他模型例如caffe模型轉過來的模型放到指定目錄下)會發生錯誤。

def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    return model

假如我們有從caffe模型轉過來的pytorch模型([0-255,BGR]),我們可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的讀取函式進行讀取即可。