1. 程式人生 > >【pytorch】載入模型出現的bug

【pytorch】載入模型出現的bug

在模型訓練完後再進行測試載入模型後出現bug,顯示如下錯誤

 

據瞭解是由於pytorch版本導致的錯誤,可能與自己訓練階段保持的模型方式有關,訓練階段儲存方式如下:

解決方案如下:

方法一:

generator.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(generator_1_10.pth).items()})

實際上就是將load進行的權重的有序字典裡面的鍵值前面的的7個字元’module.'去掉。載入進行的權重有序字典如下圖所示:

 鍵就是每層的權重或者 bias 的名稱,value就是其具體的張量值。

方法二:重新新建個有序字典:

from collections import OrderedDict
    #     new_state_dict = OrderedDict()
    #     for k, v in a.items():
    #         name=k[7:]  # reduce `module.`
    #         new_state_dict[name] = v
    #     # load params
    #     # model.load_state_dict(new_state_dict)
    #     model.load_state_dict(new_state_dict)

顯然方法一更簡潔明瞭。