
[ pytorch ] Basic usage (2): Saving and loading trained model parameters

1. Saving and loading


import torch
import torch.nn as nn

class modelfunc(nn.Module):
    # the model defined earlier
    ...

# Unlike Keras, PyTorch has no API that serializes the model architecture on its own,
# so the model class definition must be available before every load.

model_object = modelfunc()  # instantiate the model structure

# Save and load the entire model
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')

# Save and load only the model parameters
torch.save(model_object.state_dict(), 'params.pth')
model_object.load_state_dict(torch.load('params.pth'))
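
Since modelfunc above is only a placeholder, the following is a self-contained sketch of both approaches with a tiny hypothetical model (the ToyNet class and the file names are illustrative, not from the original article):

import torch
import torch.nn as nn

class ToyNet(nn.Module):
    # A tiny illustrative model, used only for this demonstration.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model_object = ToyNet()

# Approach 1: save/load the whole model. torch.save pickles the object, so the
# ToyNet class must still be importable at load time.
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')

# Approach 2: save/load only the state_dict. Rebuild the structure first, then
# copy the weights into it.
torch.save(model_object.state_dict(), 'params.pth')
new_model = ToyNet()
new_model.load_state_dict(torch.load('params.pth'))
new_model.eval()  # switch to eval mode before inference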

2. Output of torch.load:


# Save and load the entire model
torch.save(model_object, 'model.pth')  
model = torch.load('model.pth')  
print(model)

>>> [Result]
modelfunc(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
  (layer2): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
)
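
A side note on loading the whole model: if 'model.pth' was saved on a GPU and is later loaded on a CPU-only machine, torch.load accepts a map_location argument. A minimal sketch (same file name as above):

import torch

# Remap GPU-saved tensors onto the CPU while unpickling the whole model.
model = torch.load('model.pth', map_location='cpu')
print(model)  # prints the module hierarchy, as in the output above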

3. Printing the saved parameters (the state_dict loaded with torch.load('params.pth')):


# Save and load only the model parameters.
torch.save(model_object.state_dict(), 'params.pth')
# load_state_dict() copies the parameters into the model but does not return them,
# so to inspect them, print the OrderedDict returned by torch.load() instead.
dic = torch.load('params.pth')
model_object.load_state_dict(dic)
print(dic)

>>> [Result]

OrderedDict([('conv1.weight', tensor([[[[ 1.3335e-02,  1.4664e-02, -1.5351e-02,  ..., -4.0896e-02,
           -4.3034e-02, -7.0755e-02],
          [ 4.1205e-03,  5.8477e-03,  1.4948e-02,  ...,  2.2060e-03,
           -2.0912e-02, -3.8517e-02],
          [ 2.2331e-02,  2.3595e-02,  1.6120e-02,  ...,  1.0281e-01,
            6.2641e-02,  5.1977e-02],
          ...,
('bn1.weight', tensor([ 2.3888e-01,  2.9136e-01,  3.1615e-01,  2.7122e-01,  2.1731e-01,
         3.0903e-01,  2.2937e-01,  2.3086e-01,  2.1129e-01,  2.8054e-01,
         1.9923e-01,  3.1894e-01,  1.7955e-01,  1.1246e-08,  1.9704e-01,
         2.0996e-01,  2.4317e-01,  2.1697e-01,  1.9415e-01,  3.1569e-01,
         1.9648e-01,  2.3214e-01,  2.1962e-01,  2.1633e-01,  2.4357e-01,
         2.9683e-01,  2.3852e-01,  2.1162e-01,  1.4492e-01,  2.9388e-01,
         2.2911e-01,  9.2716e-02,  4.3334e-01,  2.0782e-01,  2.7990e-01,
         3.5804e-01,  2.9315e-01,  2.5306e-01,  2.4210e-01,  2.1755e-01,
         3.8645e-01,  2.1003e-01,  3.6805e-01,  3.3724e-01,  5.0826e-01,
         1.9341e-01,  2.3914e-01,  2.6652e-01,  3.9020e-01,  1.9840e-01,
         2.1694e-01,  2.6666e-01,  4.9806e-01,  2.3553e-01,  2.1349e-01,
         2.5951e-01,  2.3547e-01,  1.7579e-01,  4.5354e-01,  1.7102e-01,
         2.4903e-01,  2.5148e-01,  3.8020e-01,  1.9665e-01])), 
('bn1.bias', tensor([ 2.2484e-01,  6.0617e-01,  1.2483e-02,  1.3270e-01,  1.8030e-01,
         1.4739e-01,  1.7430e-01,  1.9023e-01,  2.3226e-01,  2.0082e-01,
         1.2834e-01, -2.1285e-01,  1.5065e-01, -3.9217e-08,  2.4985e-01,
         2.0454e-01,  5.4934e-01,  2.1021e-01,  2.2505e-01,  4.6484e-01,
         2.3888e-01,  2.0442e-01,  2.1546e-01,  6.6194e-01,  2.2755e-01,
         6.6069e-01,  2.0587e-01,  1.9292e-01,  1.1195e-01,  3.3785e-01,
         1.2393e-01,  4.1079e-02,  7.7150e-01,  2.6964e-01,  3.3347e-01,
         5.7908e-01,  1.5026e-01,  1.7534e-01,  1.9429e-01,  1.7248e-01,
         8.0577e-01,  2.3693e-01, -4.3369e-01,  8.4813e-01, -3.7857e-01,
         2.4787e-01,  1.8101e-01,  3.2949e-01, -2.8598e-01,  2.2717e-01,
         2.6168e-01,  5.7609e-02, -5.0320e-01,  1.5704e-01,  1.7890e-01,
         2.8114e-01,  4.2167e-01, -9.7650e-02, -3.1231e-01, -2.5637e-02,
         8.8566e-02,  1.8052e-01,  8.3045e-01,  2.5015e-01])), 
('bn1.running_mean', tensor([ 2.8781e-02,  1.0830e-01,  2.6812e-01, -4.7955e-02, -2.7350e-02,
        -1.2350e-02, -2.8534e-02,  3.8390e-02,  8.6643e-03,  1.1076e-01,
        -1.6231e-02, -7.1499e-01,  5.7644e-02, -5.1895e-07, -1.9860e-02,
         6.5988e-03,  4.9869e-01, -3.4726e-02, -2.2373e-02, -6.4198e-01,
         3.3326e-02,  6.5970e-02,  3.1869e-02,  3.1863e-01,  3.7692e-02,
         4.9075e-01,  3.0402e-02, -6.5330e-02, -2.4589e-02,  4.3018e-01,
        -6.3207e-02,  3.6987e-02, -7.9438e-01,  3.7037e-02,  8.1242e-01,
        -8.8931e-01, -3.4412e-02, -1.6578e-01, -1.8018e-02, -2.7667e-02,
        -1.3835e+00,  7.8008e-02, -7.0342e-01,  3.4551e-01,  5.7252e-01,
         4.5663e-02,  5.2766e-02,  2.8974e-01, -3.4401e-01,  1.6897e-02,
         9.7269e-02, -2.1634e-02,  7.9793e-01,  1.7612e-02, -3.2805e-03,
        -1.7782e-01, -1.4005e-01,  4.1215e-02,  7.2888e-01, -2.2417e-01,
         1.9287e-03,  8.7772e-02,  1.3144e+00, -3.8825e-02])), 
('bn1.running_var', tensor([ 5.0796e-01,  1.4441e+00,  3.3001e+00,  3.3098e+00,  1.3029e-01,
         3.3023e+00,  1.2143e-01,  2.5986e-01,  8.9925e-02,  2.9480e+00,
         1.3752e-01,  2.1341e+00,  6.9679e-02,  2.7234e-12,  2.4457e-02,
         6.9063e-02,  1.1395e+00,  8.0611e-02,  2.1984e-02,  2.6701e+00,
         5.6415e-02,  2.1792e-01,  1.0816e-01,  9.8851e-01,  3.0843e-01,
         2.9959e+00,  5.4037e-02,  1.7887e-01,  2.8518e-02,  1.8343e+00,
         7.0009e-01,  2.9475e-02,  1.1048e+01,  7.5987e-03,  2.6686e+00,
         5.0308e+00,  2.8717e+00,  1.7434e+00,  3.8133e-01,  1.3055e-01,
         8.6697e+00,  3.9596e-02,  2.3990e+00,  3.7014e+00,  6.9698e+00,
         1.2682e-01,  1.4923e-01,  1.5581e+00,  1.1554e+00,  2.0051e-02,
         1.3014e-01,  9.9781e-01,  3.6349e+00,  2.4568e-01,  1.2094e-01,
         7.6329e-01,  7.9295e-01,  1.5916e-01,  3.8380e+00,  3.2014e-01,
         3.4269e-01,  3.3512e-01,  8.0546e+00,  2.4255e-02])), ('layer1.0.conv1.weight', tensor([[[[ 3.5144e-03]],

         [[ 3.9855e-02]],

         [[-2.4795e-02]],

         ...,
 
('layer1.0.bn1.weight', tensor([ 2.1341e-01,  1.8848e-01,  1.4136e-01,  1.5273e-01,  1.3220e-01,
         1.8735e-01,  1.4475e-01,  4.5110e-08,  1.5993e-01,  1.4946e-01,
         2.3499e-01,  1.8315e-01,  1.8516e-01,  1.4933e-01,  1.3090e-01,
         1.0634e-01,  3.7487e-01,  1.2644e-01,  3.1895e-01,  2.7160e-01,
         2.5810e-01,  2.9458e-01,  1.8395e-01,  2.1088e-08,  3.3313e-01,
         2.0461e-01,  3.0399e-01,  1.1805e-08,  1.4977e-01,  1.5719e-01,
         1.4011e-01,  1.4900e-01,  1.2438e-01,  1.8786e-01,  1.4257e-01,
         3.4828e-01,  1.5038e-01,  3.0034e-01,  2.5925e-01,  1.0711e-01,
         2.6875e-01,  1.3552e-01,  1.1822e-01,  1.1189e-01,  2.8736e-01,
         3.2637e-01,  1.4781e-01,  2.3105e-01,  3.3638e-01,  2.8808e-01,
         1.2319e-01,  3.0763e-01,  1.1846e-01,  1.3137e-01,  2.0671e-01,
         1.5787e-01,  2.6574e-08,  2.0467e-01,  2.8797e-08,  1.8284e-01,
         3.0180e-01,  1.7401e-01,  2.8438e-01,  2.3715e-01])), 

('layer1.0.bn1.bias', tensor([ 4.3266e-01,  4.6854e-02, -8.0134e-02,  7.3302e-02,  2.7970e-01,
        -7.8047e-03,  9.4087e-02, -1.0086e-07, -1.4034e-01, -5.1599e-02,
         4.4470e-02,  2.1814e-01,  4.0718e-02,  1.1979e-01,  1.4432e-01,
         1.3672e-01, -1.1168e-01,  1.4774e-01, -1.2879e-01, -5.3147e-02,
        -3.3920e-02, -2.0600e-02,  6.2783e-02, -6.5736e-08, -7.1213e-02,
         6.9510e-02, -1.3264e-01, -6.4411e-08, -2.8908e-02,  9.4164e-02,
         2.4790e-01, -8.2850e-02, -2.8872e-02, -1.7086e-01,  9.9522e-02,
        -1.1357e-01,  1.9770e-01,  1.4800e-02, -7.0896e-02,  1.0722e-01,
         1.2536e-02, -3.6633e-02,  1.4959e-01,  1.0533e-01,  2.0933e-02,
        -1.0502e-01, -4.8848e-02,  4.9007e-01, -1.4755e-01, -1.0900e-01,
         1.9815e-02, -7.0964e-02, -4.6543e-02,  1.0874e-01, -2.7878e-01,
         4.4500e-03, -7.7156e-08,  7.5060e-02, -8.4474e-08,  2.2533e-01,
        -7.1593e-02, -1.5823e-01, -3.4459e-02,  5.2894e-01])), 

('layer1.0.bn1.running_mean', tensor([-6.0619e-01, -3.5467e-01,  2.4651e-01, -2.5210e-01, -7.6892e-02,
        -3.3654e-01, -1.0111e-01, -1.7881e-08,  2.1631e-01, -2.8016e-01,
        -3.1948e-01,  1.1134e+00, -1.1791e-01, -2.0125e-01, -3.2957e-01,
        -2.6431e-02, -3.4833e-01,  7.1402e-01, -2.7727e-01, -2.7576e-01,
        -1.7791e-01, -1.1054e-01, -1.5952e-01, -5.6052e-45, -3.6867e-01,
        -1.7413e-01, -2.6344e-01,  4.3125e-09, -2.3616e-01, -3.0546e-01,
        -1.8908e-02,  2.2109e-01,  1.1146e-02, -1.4291e-01, -3.0156e-01,
        -4.4344e-01, -2.2829e-01, -2.0861e-01, -2.2197e-01,  3.1603e-01,
        -1.1507e-01, -1.3784e-01, -2.9271e-01, -4.8246e-01, -1.5741e-01,
        -2.6682e-01, -3.8136e-01, -3.1360e-01, -1.9755e-01, -4.1116e-01,
        -2.8717e-02, -3.0186e-01,  8.8766e-02, -3.3887e-01, -5.9848e-02,
        -6.4817e-01,  1.2924e-09, -2.2738e-01, -5.6052e-45,  1.0252e+00,
        -9.3871e-02, -1.4969e-02, -4.0218e-01, -1.3630e-01])),
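
Printing the full OrderedDict dumps every tensor value, which is hard to read. A minimal sketch for listing only the parameter names and shapes (same 'params.pth' as above):

import torch

# List each saved parameter/buffer name together with its tensor shape.
state_dict = torch.load('params.pth')
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))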