1. 程式人生 > >小白程式設計用Pytorch匯入預訓練模型&&設定不同學習速率

小白程式設計用Pytorch匯入預訓練模型&&設定不同學習速率

前兩天正好在做這個部分,參考了很多網友的做法,也去pytorch論壇查了一下,現在總結如下。建議還是自己單步除錯一下看看每個引數裡面的值是什麼樣的比較好。

1.匯入預訓練的模型,預訓練模型是現有模型的一個或者幾個部分

假設我有一個網路包含 pretrained和classify兩個部分,每個部分分別有一些卷積層or迴歸層,pretrained部分有一個已經訓練好的網路模型pretrained model,那麼我需要把這個網路模型匯入到現有的網路中,實現程式碼如下:

# load pretrained model
pretrained_net = PretrainedNet()
pretrained_net.load_state_dict(torch.load('epochs/epoch_3_100.pt'))
pretrained_dict = pretrained_net.state_dict() 

# prepare my model_dict
model = MyNet()
model_dict = model.state_dict()

# trained_part 是pretrained部分在現有網路的名稱,個人喜好把網路標記的明確一些,所以pretrained的部分都會寫成一個pretrained類,然後呼叫,這樣這個子塊的每個引數名稱就變成pretrained.xx.weight這樣,有的人喜歡直接把pretrained部分寫成跟pretrained model一樣的引數名稱,都OK,去掉pretrained_part這個變數就好。
pretrained_part = 'pretrained.'
dict_temp = {pretrained_part + k: v for k, v in pretrained_dict.items() if pretrained_part + k in model_dict}
model_dict.update(dict_temp)
model.load_state_dict(model_dict)

這部分一般是在train的時候,例項化訓練網路之後,匯入預訓練模型。匯入模型後,有時候需要固定這部分的引數或者給他們一個很低的學習速率,這時候就要開始給設定不同的學習速率。

2.給不同的部分設定學習速率。

固定pretrained的引數,僅僅訓練classify

# setting the pretrained part leaning rate as zero, Only train the classifier part

for param in list(model.pretrained.parameters()):
    param.requires_grad = False

params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(params, lr=1e-3)

分別給pretrained和classify部分設定不同的學習速率

# pretrained_params 是給filter的一個list,用來過濾,其中的值是int,所以在optimizer設定引數的時候,不能用pretrained_params,應當直接使用model.pretrained.parameters()

pretrained_params = list(map(id, model.pretrained.parameters()))
classify_params = filter(lambda p: id(p) not in pretrained_params, model.parameters())

optimizer = optim.Adam([{'params': classify_params},
            {'params': model.pretrained.parameters(), 'lr': 1e-5}], lr=1e-3)

#設定學習速率的step方式
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) 
model.children_model.paramters()可以直接遍歷子網路裡面的parameter,不用把每個卷積層的parameter都列出來。
pytorch還在摸索中,自己實現個程式碼對程式設計小白來說還是比較輕鬆的,現在使用到什麼學什麼,勉強夠用吧,版本0.3.1。因為聽說0.4有大改動,擔心之前的程式碼不能用,暫時不升級了。