1. 程式人生 > >Pytorch模型的儲存與載入

Pytorch模型的儲存與載入

前言

在使用Pytorch訓練模型的時候,經常會有在GPU上儲存模型然後再CPU上執行的需求,在實驗的過程中發現在多GPU上訓練的Pytorch模型是不能在CPU上直接執行的,幾次遇到了這種問題,這裡研究和記錄一下。

模型的儲存與載入

例如我們建立了一個模型:

model = MyVggNet()

如果使用多GPU訓練,我們需要使用這行程式碼:

model = nn.DataParallel(model).cuda()

執行這個程式碼之後,model就不在是我們原來的模型,而是相當於在我們原來的模型外面加了一層支援CPU執行的外殼,這時候真正的模型物件為:real_model = model.module

, 所以我們在儲存模型的時候注意,如果儲存的時候是否帶有這層加的外殼,如果儲存的時候帶有的話,載入的時候也是帶有的,如果儲存的是真實的模型,載入的也是真是的模型。這裡我建議儲存真是的模型,因為加了module殼的模型在CPU上是不能執行的。
Pytorch有多種儲存模型的方式,使用哪種進行儲存,就要使用對應的載入方式。儲存的時候模型的字尾名是無所謂的。
Pytorch官方的載入和儲存模型的方式有兩種:
1. 儲存和載入整個模型

torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
  1. 僅儲存和載入模型引數(推薦使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

模型儲存與載入對應方式

1. 第一種方式

儲存使用:

real_model = model.module
torch.save(real_model.state_dict(),os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_weight.pth"))

cpu上載入使用:

args.weight=checkpoint/cos_mnist_10_weight.pth
map_location = lambda
storage, loc: storage model.load_state_dict(torch.load(args.weight,map_location=map_location))

2. 第二種方式

儲存使用:

real_model = model.module
save_model(real_model, os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_weight_cpu.pth"))
# 自定義的函式
def save_model(model,filename):
    state = model.state_dict()
    for key in state: state[key] = state[key].clone().cpu()
    torch.save(state, filename)

cpu上載入使用:

args.weight=checkpoint/cos_mnist_10_weight_cpu.pth
model.load_state_dict(torch.load(args.weight))

3. 第三種方式

儲存使用:

real_model = model.module
torch.save(real_model, os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_whole.pth"))

cpu上載入使用:

args.weight=checkpoint/cos_mnist_10_whole.pth
map_location = lambda storage, loc: storage
model = torch.load(args.weight,map_location=map_location)

參考文獻

相關推薦

[PyTorch 學習筆記] 7.1 模型儲存載入

> 本章程式碼: > > - [https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py](https://github.com/zhangxiann/PyTorch_Practice/b

TensorFlow實現Softmax迴歸(模型儲存載入

1 # -*- coding: utf-8 -*- 2 """ 3 Created on Thu Oct 18 18:02:26 2018 4 5 @author: zhen 6 """ 7 8 from tensorflow.examples.tutorials.mnist imp

python opencv3.x中支援向量機(svm)模型儲存載入問題

親自驗證,可以解決svm的模型載入問題:     import numpy as np     from sklearn import datasets         &nb

keras中訓練好的模型儲存載入

keras中的採用Sequential模式建立DNN並持久化保持、重新載入 def DNN_base_v1(X_train, y_train): model = models.Sequential() model.add(layers.Dense(96,

Keras 深度學習程式碼筆記——模型儲存載入

你可以使用model.save(filepath)將Keras模型和權重儲存在一個HDF5檔案中,該檔案將包含: 模型的結構,以便重構該模型 模型的權重 訓練配置(損失函式,優化器等) 優化器的狀態,以便於從上次訓練中斷的地方開始 使用keras.mod

tensorflow 模型儲存載入

在訓練一個神經網路模型後,你會儲存這個模型未來使用或部署到產品中。所以,什麼是TF模型?TF模型基本包含網路設計或圖,與訓練得到的網路引數和變數。因此,TF模型具有兩個主要檔案: a)meta圖 這是一個擬定的快取,包含了這個TF圖完整資訊;如所有變數等

Keras中的模型儲存載入

from keras.models import Sequential from keras.layers import Dense from keras.models import load_model model = Sequential() model.add(Dens

Pytorch模型儲存載入

前言 在使用Pytorch訓練模型的時候,經常會有在GPU上儲存模型然後再CPU上執行的需求,在實驗的過程中發現在多GPU上訓練的Pytorch模型是不能在CPU上直接執行的,幾次遇到了這種問題,這裡研究和記錄一下。 模型的儲存與載入 例如我們建立了一

【小白學PyTorch】19 TF2模型儲存載入

【新聞】:機器學習煉丹術的粉絲的人工智慧交流群已經建立,目前有目標檢測、醫學影象、時間序列等多個目標為技術學習的分群和水群嘮嗑的總群,歡迎大家加煉丹兄為好友,加入煉丹協會。微信:cyx645016617. 參考目錄: [TOC] 本文主要講述TF2.0的模型檔案的儲存和載入的多種方法。主要分成兩型別:模型

tensorflow模型儲存載入

1.儲存:(儲存的變數都是停放,tf.Variable()中的變數,變數一定要有名字) saver = tf.train.Saver() saver.run(sess,"./model4/line_model.ckpt")   2.檢視儲存的變數資訊:(將儲存的資訊打印

基於pytorch儲存載入模型引數

當我們花費大量的精力訓練完網路,下次預測資料時不想再(有時也不必再)訓練一次時,這時候torch.save(),torch.load()就要登場了。 儲存和載入模型引數有兩種方式: 方式一:   torch.save(net.state_dict(),path): 功能

Keras儲存載入模型(JSON+HDF5)

在Keras中,有時候需要對模型進行序列化與反序列化。進行模型序列化時,會將模型結果與模型權重儲存在不同的檔案中,模型權重通常儲存在HDF5檔案中,模型的結構可以儲存在JSON或者YAML檔案中。後二者方法大同小異,這裡以JSON為例說明一下Keras模型的儲存與載入。 from sklearn

python sklearn svm模型儲存載入呼叫

對於機器學習的一些模型,跑完之後,如果下一次測試又需要重新跑一遍模型是一件很繁瑣的事,這時候我們就需要儲存模型,再載入呼叫。 樓主發現有這些儲存模型的方法,網上有很多錯誤的例子,所以給大家在整理一下。(python3) 1.利用pickle import pickle

[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存載入(pb方式)

[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式) 在上篇博文中,我們探索了TensorFlow模型引數儲存與載入實現方法採用的是儲存ckpt的方式。這篇博文我們會使用儲存為pd格式檔案來實現。 首先,我會在上篇博文基礎上,實現由c

[TensorFlow深度學習入門]實戰七·簡便方法實現TensorFlow模型引數儲存載入(ckpt方式)

[TensorFlow深度學習入門]實戰七·簡便方法實現TensorFlow模型引數儲存與載入(ckpt方式) TensorFlow模型訓練的好網路引數如果想重複高效利用,模型引數儲存與載入是必須掌握的模組。本文提供一種簡單容易理解的方式來實現上述功能。參考部落格地址 備註: 本文采用的

模型儲存載入呼叫

模型儲存 BP: model.save(save_dir) SVM: from sklearn.externals import joblib joblib.dump(clf, save_dir) 模型載入 BP: from keras.models im

pytorch資料載入模型儲存載入

主要涉及的Pytorch官方示例下圖紅框部分的一些翻譯及備註。 1、資料載入及處理   該部分主要是用於進行資料集載入及資料預處理說明,使用的資料集為:人臉+標註座標。demo程式需要pandas(讀取CSV檔案)及scikit-image(影象變換)這兩個包。 1.1、jup

Keras 儲存載入網路模型

遇到問題: keras使用預訓練模型做訓練時遇到的如下程式碼: from keras.utils.data_utils import get_file WEIGHTS_PATH = 'https://github.com/fchollet/deep-lea

TensorFlow下網路模型儲存載入

#!/usr/bin/env python# 匯入mnist資料庫from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)i

模型儲存載入

tensorflow: 有兩種方式儲存和載入模型。 ①生成checkpoint file,副檔名為.ckpt,通過在tf.train.Saver物件上呼叫Saver.save()生成。包含權重和變數,但不包括圖的結構。如果需要在另一個程式中使用,需要重新建立