1. 程式人生 > >pytorch fine-tune 預訓練的模型

pytorch fine-tune 預訓練的模型

之一:

torchvision 中包含了很多預訓練好的模型,這樣就使得 fine-tune 非常容易。本文主要介紹如何 fine-tune torchvision 中預訓練好的模型。

安裝

pip install torchvision

如何 fine-tune

以 resnet18 為例:

from torchvision import models
from torch import nn
from torch import optim

resnet_model = models.resnet18(pretrained=True) 
# pretrained 設定為 True,會自動下載模型 所對應權重,並載入到模型中
# 也可以自己下載 權重,然後 load 到 模型中,原始碼中有 權重的地址。

# 假設 我們的 分類任務只需要 分 100 類,那麼我們應該做的是
# 1. 檢視 resnet 的原始碼
# 2. 看最後一層的 名字是啥 (在 resnet 裡是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替換掉這個層
resnet_model.fc= nn.Linear(in_features=..., out_features=100)

# 這樣就 哦了,修改後的模型除了輸出層的引數是 隨機初始化的,其他層都是用預訓練的引數初始化的。

# 如果只想訓練 最後一層的話,應該做的是:
# 1. 將其它層的引數 requires_grad 設定為 False
# 2. 構建一個 optimizer, optimizer 管理的引數只有最後一層的引數
# 3. 然後 backward, step 就可以了

# 這一步可以節省大量的時間,因為多數的引數不需要計算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 

optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)

...

為什麼

這裡介紹下 執行resnet_model.fc= nn.Linear(in_features=..., out_features=100)時 框架內發生了什麼

這時應該看 nn.Module 原始碼的 __setattr__ 部分,因為 setattr 時都會呼叫這個方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼簾就是 remove_from

 這個函式,這個函式的目的就是,如果出現了 同名的屬性,就將舊的屬性移除。 用剛才舉的例子就是:

  • 預訓練的模型中 有個 名字叫fc 的 Module。
  • 在類定義外,我們 將另一個 Module 重新 賦值給了 fc
  • 類定義內的 fc 對應的 Module 就會從 模型中 刪除。

之二:

前言

這篇文章算是論壇PyTorch Forums關於引數初始化和finetune的總結,也是我在寫程式碼中用的算是“最佳實踐”吧。最後希望大家沒事多逛逛論壇,有很多高質量的回答。

引數初始化

引數的初始化其實就是對引數賦值。而我們需要學習的引數其實都是Variable,它其實是對Tensor

的封裝,同時提供了datagrad等藉口,這就意味著我們可以直接對這些引數進行操作賦值了。這就是PyTorch簡潔高效所在。 這裡寫圖片描述  所以我們可以進行如下操作進行初始化,當然其實有其他的方法,但是這種方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance來判斷m屬於什麼型別
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其實都是Variable,為了能學習引數以及後向傳播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在載入了預訓練模型的引數之後,我們需要finetune模型,可以使用不同的方式finetune。

區域性微調

有時候我們載入了訓練模型後,只想調節最後的幾層,其他層不訓練。其實不訓練也就意味著不進行梯度計算,PyTorch中提供的requires_grad使得對訓練的控制變得非常簡單。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替換最後的全連線層, 改為訓練100類
# 新構造的模組的引數預設requires_grad為True
model.fc = nn.Linear(512, 100)

# 只優化最後的分類層
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全域性微調

有時候我們需要對全域性都進行finetune,只不過我們希望改換過的層和其他層的學習速率不一樣,這時候我們可以把其他層和新層在optimizer中單獨賦予不同的學習速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())

optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3來訓練,model.fc.parameters使用1e-2來訓練,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要講述如何在pytorch上讀取以往訓練的模型引數,在模型的名字已經變更的情況下又如何讀取模型的部分引數等。                                                                                        --------作者:jiangwenj02【轉載請註明】

pytorch 模型的儲存與讀取

其中在模型的儲存過程有儲存模型和引數一起的也有單獨儲存模型引數的

單獨儲存模型引數

儲存時使用:

torch.save(the_model.state_dict(), PATH)

讀取時:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

儲存模型與引數

儲存:

torch.save(the_model, PATH)

讀取:

the_model = torch.load(PATH)

模型的引數

fine-tune的過程是讀取原有模型的引數,但是由於模型的所要處理的資料集不同,最後的一層class的總數不同,所以需要修改模型的最後一層,這樣模型讀取的引數,和在大資料集上訓練好下載的模型引數在形式上不一樣。需要我們自己去寫函式讀取引數。

pytorch模型引數的形式

模型的引數是以字典的形式儲存的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的鍵值 如果想修改模型的引數,給相應的鍵值賦值即可

model_dict[k] = new_value

最後更新模型的引數

the_model.load_state_dict(model_dict)

如果模型的key值和在大資料集上訓練時的key值是一樣的

我們可以通過下列演算法進行讀取模型

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大資料集上訓練時的key值是不一樣的,但是順序是一樣的

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大資料集上訓練時的key值是不一樣的,但是順序是也不一樣的

自己找對應關係,一個key對應一個key的賦值