1. 程式人生 > >[Pytorch]Pytorch 保存模型與加載模型(轉)

[Pytorch]Pytorch 保存模型與加載模型(轉)

filter time class ada enume req update ems 實現

轉自:知乎

目錄:

  • 保存模型與加載模型
  • 凍結一部分參數,訓練另一部分參數
  • 采用不同的學習率進行訓練

1.保存模型與加載

簡單的保存與加載方法:

# 保存整個網絡
torch.save(net, PATH)
# 保存網絡中的參數, 速度快,占空間少
torch.save(net.state_dict(),PATH)
#--------------------------------------------------
#針對上面一般的保存方法,加載的方法分別是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH
))


然而,在實驗中往往需要保存更多的信息,比如優化器的參數,那麽可以采取下面的方法保存:

torch.save({‘epoch‘: epochID + 1, ‘state_dict‘: model.state_dict(), ‘best_loss‘: lossMIN,
‘optimizer‘: optimizer.state_dict(),‘alpha‘: loss.alpha, ‘gamma‘: loss.gamma},
checkpoint_path + ‘/m-‘ + launchTimestamp + ‘-‘ + str("%.4f" % lossMIN) + ‘.pth.tar‘)

以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定義損失函數的兩個參數;格式以字典的格式存儲。

加載的方式:

def load_checkpoint(model, checkpoint_PATH, optimizer):
if checkpoint != None:
model_CKPT = torch.load(checkpoint_PATH)
model.load_state_dict(model_CKPT[‘state_dict‘])
print(‘loading checkpoint!‘)
optimizer.load_state_dict(model_CKPT[‘optimizer‘])
return model, optimizer

其他的參數可以通過以字典的方式獲得

但是,但是,我們可能修改了一部分網絡,比如加了一些,刪除一些,等等,那麽需要過濾這些參數,加載方式:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
if checkpoint != ‘No‘:
print("loading checkpoint...")
model_dict = model.state_dict()
modelCheckpoint = torch.load(checkpoint)
pretrained_dict = modelCheckpoint[‘state_dict‘]
# 過濾操作
new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
model_dict.update(new_dict)
# 打印出來,更新了多少的參數
print(‘Total : {}, update: {}‘.format(len(pretrained_dict), len(new_dict)))
model.load_state_dict(model_dict)
print("loaded finished!")
# 如果不需要更新優化器那麽設置為false
if loadOptimizer == True:
optimizer.load_state_dict(modelCheckpoint[‘optimizer‘])
print(‘loaded! optimizer‘)
else:
print(‘not loaded optimizer‘)
else:
print(‘No checkpoint is included‘)
return model, optimizer

2.凍結部分參數,訓練另一部分參數

1)添加下面一句話到模型中

for p in self.parameters():
p.requires_grad = False

比如加載了resnet預訓練模型之後,在resenet的基礎上連接了新的模快,resenet模塊那部分可以先暫時凍結不更新,只更新其他部分的參數,那麽可以在下面加入上面那句話

class RESNET_MF(nn.Module):
def init(self, model, pretrained):
super(RESNET_MF, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False
self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
...

同時在優化器中添加:filter(lambda p: p.requires_grad, model.parameters())

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
eps=1e-08, weight_decay=1e-5)

2) 參數保存在有序的字典中,那麽可以通過查找參數的名字對應的id值,進行凍結

查找的代碼:

 model_dict = torch.load(‘net.pth.tar‘).state_dict()
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
print(i, p)

保存一下這個文件,可以看到大致是這個樣子的:

0 gamma
1 resnet.conv1.weight
2 resnet.bn1.weight
3 resnet.bn1.bias
4 resnet.bn1.running_mean
5 resnet.bn1.running_var
6 resnet.layer1.0.conv1.weight
7 resnet.layer1.0.bn1.weight
8 resnet.layer1.0.bn1.bias
9 resnet.layer1.0.bn1.running_mean
....

同樣在模型中添加這樣的代碼:

for i,p in enumerate(net.parameters()):
if i < 165:
p.requires_grad = False

在優化器中添加上面的那句話可以實現參數的屏蔽

[Pytorch]Pytorch 保存模型與加載模型(轉)