1. 程式人生 > >pytorch資料載入、模型儲存及載入

pytorch資料載入、模型儲存及載入

主要涉及的Pytorch官方示例下圖紅框部分的一些翻譯及備註。
在這裡插入圖片描述

1、資料載入及處理

  該部分主要是用於進行資料集載入及資料預處理說明,使用的資料集為:人臉+標註座標。demo程式需要pandas(讀取CSV檔案)及scikit-image(影象變換)這兩個包。

1.1、jupyter顯示matplot影象

import matplotlib.pyplot as plt
%matplotlib inline   #這句是在jupyter顯示影象的關鍵,在其它IDE中必須註釋掉,否則報錯  

1.2、資料集類

  torch.utils.data.Dataset

是一個處理資料的抽象類。當使用自己的資料集時需要繼承Dataset類,並且過載以下成員函式:
《1》、len : 用於返回資料集的大小。
《2》、getitem : 通過下標索引取得第i個樣本。

demo程式中為臉部標註樣本建立了一個FaceLandmarksDataset類。在該類的__init__方法中讀取csv檔案,在__getitem__方法中載入圖片。
我們建立的資料集樣本會以一個字典表示,如下:

{'image': image, 'landmarks': landmarks}

該資料集類有一個可選引數“transform”,用於控制對影象進行的處理。

1.3、資料轉換(transforms)

  幾乎所有神經網路的輸入都希望接收到大小固定的資料,而我們demo中的原始影象大小是不一致的。因此我們新增一些影象變換方法來處理這些影象。包括以下三個:
  Rescale: 縮放圖片
  RandomCrop: 隨機裁剪
  ToTensor:將numpy表示的影象轉化為torch的Tensor表示
將每一個影象變換用一個可呼叫的類實現。這樣做的好處是–進行變換時的引數不用每次都在迭代上下文傳遞。為此實現了類的__call__ 專有函式

__call__、__getitem__
python專有函式。
若在定義類的時候,實現__call__函式,則這個類就成為可呼叫的。換句話說,我們可以把這個類的例項當做函式來使用。
相當於過載了括號運算子。

例子說明:

class g_dpm(object):
    def __init__(self, g):
        self.g = g

    def __call__(self, t):
        return (self.g*t**2)/2

  計算地球場景的時候,我們就可以令e_dpm = g_dpm(9.8),s = e_dpm(t)

1.4、組合影象變換

  我們需要將demo中使用的樣本圖片較短的邊設定為256,之後將圖片裁剪為244x244大小。為此我們需要組合Rescale和RandomCrop兩種變換。可以通過torchvision.transforms.Compose實現。這是一個可呼叫類。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

1.5、在資料集上進行迭代處理

   如果我們在訓練的每次迭代中只從資料集中抽取一張圖片,那麼我們會丟失很多和資料集有關的特徵。因此我們每次迭代我們採用以下方法:
  《1》、批處理
  《2》、資料重混(打亂資料)
  《3》、用多個程序並行載入資料
原文:
在這裡插入圖片描述
torch提供了torch.utils.data.DataLoader用於實現以上3點。torch.utils.data.DataLoader是一個迭代器,他有一個collate_fn的引數需要特別留意下,該引數用於合併一些list形式樣本來形成一個小批量( merges a list of samples to form a mini-batch)。
原文:
在這裡插入圖片描述

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

1.6、Torchvision

  torchvision包提供了一些常用的資料集類和資料處理實現。該包中最常用的資料集類是ImageFolder。
該類假設圖片按照以下方式儲存。
在這裡插入圖片描述
上圖中bees、ants都是類標籤。ImageFloder使用例子:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

2、模型載入及儲存

2.1、與模型載入有關的3個函式:

《1》、torch.save

torch.save(obj, f, pickle_module=<module 'pickle' from '/scratch/rzou/pt/release-env/lib/python3.7/pickle.py'>, pickle_protocol=2)

功能:將模型儲存到磁碟。該函式使用python的pickle包來序列化模型。
官方推薦兩種用法:
A、僅僅儲存模型引數;
B、儲存整個模型;
在這裡插入圖片描述
《2》、torch.load

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/scratch/rzou/pt/release-env/lib/python3.7/pickle.py'>)

功能:載入由torch.load()函式儲存的模型。
  該函式首先會將模型反序列化到CPU然後將模型移動到儲存模型時該模型所處的裝置(CPU或GPU)。如果現有機器上沒有對應儲存模型時的裝置,則該函式會丟擲異常。如果遇到這種情況,可以使用該函式的map_location引數來將模型動態對映到一系列裝置上。
在這裡插入圖片描述

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt') as f:
        buffer = io.BytesIO(f.read())
>>> torch.load(buffer)

《3》、torch.nn.Module.load_state_dict

load_state_dict(state_dict, strict=True)

功能:僅僅載入模型的引數。
在這裡插入圖片描述

2.2、STATE_DICT

  在pytorch中,torch.nn.model中的可學習引數(權重,bias(偏差)等)都儲存在模型的parameters成員中,可通過model.parameters()獲取。
stat_dict是一個字典,該字典包含model每一層的tensor型別的可學習引數。只有包含可學習引數的網路層才能將其引數對映到state_dict字典中。
原文:
在這裡插入圖片描述

例子:

#定義網路模型用於說明sate_dict
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.cov1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward():
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

輸出:
在這裡插入圖片描述
由列印知道模型引數包括兩大類。一類是權重及偏差引數,另一類是Optimizer引數。optimizer的state_dict包含兩個關鍵字:優化器的state及超引數。

2.3、儲存模型及載入模型用於預測

a、儲存
推薦僅僅儲存模型的state_dict

torch.save(model.state_dict(), MODELPATH)

b、載入

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

Pytorch儲存的模型字尾一般是.pt或者.pth
必須在載入模型後呼叫model.eval函式來將dropout及批歸一化層設定為預測模式。如果不這麼做結果出錯。

2.4、儲存臨時模型用於預測或再訓練

a、儲存

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

  當儲存一個臨時模型用於預測或再訓練時,需要儲存比state_dict更多的引數。包括優化器的state_dict,迭代次數epoch,最後一層迭代的loss及其他任何需要的引數。
  當儲存多個元件時,將多個元件以字典的形式組織,然後用torch.savee()來序列化該字典。在Pytorch中常用.tar檔案字尾表示這種模型。

b、載入

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()   #預測
# - or -
model.train() #再訓練

2.5、將多個模型儲存在一個檔案中

a、儲存:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

b、載入:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

2.6、利用一個不同的模型來預熱(warmstarting)待使用的模型【WARMSTARTING MODEL USING PARAMETERS FROM A DIFFERENT MODEL】

在這裡插入圖片描述
載入一個模型的一部分或者載入一個不完整的網路在遷移學習或者訓練一個新的複雜網路時會經常遇到。
使用已經訓練過的引數,即使這些引數僅僅是待訓練網路引數的一小部分,也會加快網路的訓練及幫助網路更快達到收斂。

2.7、在不同裝置上進行模型的儲存及載入

《1》、GPU上儲存,CPU上載入
在這裡插入圖片描述
在這種情況下,tensor的使用的記憶體會自動重對映到CPU裝置中。

《2》、GPU上儲存,GPU上載入
在這裡插入圖片描述
該場景下需要注意:必須將模型的所有輸入使用.to(torch.device(“cuda”))轉為GPU使用的型別。
注意:
  my_tensor.to(device)返回的是my_tensor的一個新的拷貝,該操作不會覆蓋my_tensor原本的device型別(CPU或GPU)
覆蓋式的寫法:

 my_tensor = my_tensor.to(device)

《3》、CPU上儲存,GPU上載入
在這裡插入圖片描述
比起型別2,在呼叫load_state_dict函式時多一個map_loaction操作。其它操作同類型2.
《4》、儲存並行資料模型(torch.nn.DataParallel)
在這裡插入圖片描述
torch.nn.DataParallel模型是一個封裝好的模型,該模型能使用GPU的並行處理操作。

3、Jupyter demo