1. 程式人生 > >PyTorch學習系列(十四)——儲存訓練好的模型

PyTorch學習系列(十四)——儲存訓練好的模型

PyTorch提供了兩種儲存訓練好的模型的方法。
第一種是隻儲存模型引數,這也是推薦的方法:

#儲存
torch.save(the_model.state_dict(), PATH)
#讀取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二種方法儲存整個模型:

#儲存
torch.save(the_model, PATH)
#讀取
the_model = torch.load(PATH)