1. 程式人生 > >pytorch學習筆記(十一):fine-tune 預訓練的模型

pytorch學習筆記(十一):fine-tune 預訓練的模型

torchvision 中包含了很多預訓練好的模型,這樣就使得 fine-tune 非常容易。本文主要介紹如何 fine-tune torchvision 中預訓練好的模型。

安裝

pip install torchvision

如何 fine-tune

以 resnet18 為例:

from torchvision import models
from torch import nn
from torch import optim

resnet_model = models.resnet18(pretrained=True) 
# pretrained 設定為 True,會自動下載模型 所對應權重,並載入到模型中
# 也可以自己下載 權重,然後 load 到 模型中,原始碼中有 權重的地址。 # 假設 我們的 分類任務只需要 分 100 類,那麼我們應該做的是 # 1. 檢視 resnet 的原始碼 # 2. 看最後一層的 名字是啥 (在 resnet 裡是 self.fc = nn.Linear(512 * block.expansion, num_classes)) # 3. 在外面替換掉這個層 resnet_model.fc= nn.Linear(in_features=..., out_features=100) # 這樣就 哦了,修改後的模型除了輸出層的引數是 隨機初始化的,其他層都是用預訓練的引數初始化的。
# 如果只想訓練 最後一層的話,應該做的是: # 1. 將其它層的引數 requires_grad 設定為 False # 2. 構建一個 optimizer, optimizer 管理的引數只有最後一層的引數 # 3. 然後 backward, step 就可以了 # 這一步可以節省大量的時間,因為多數的引數不需要計算梯度 for para in list(resnet_model.parameters())[:-2]: para.requires_grad=False optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3
) ...

為什麼

這裡介紹下 執行resnet_model.fc= nn.Linear(in_features=..., out_features=100)時 框架內發生了什麼

這時應該看 nn.Module 原始碼的 __setattr__ 部分,因為 setattr 時都會呼叫這個方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼簾就是 remove_from 這個函式,這個函式的目的就是,如果出現了 同名的屬性,就將舊的屬性移除。 用剛才舉的例子就是:

  • 預訓練的模型中 有個 名字叫fc 的 Module。
  • 在類定義外,我們 將另一個 Module 重新 賦值給了 fc
  • 類定義內的 fc 對應的 Module 就會從 模型中 刪除。