
PyTorch model saving and loading


In deep learning it is very useful to save a model and its parameters during or after training, and to reload them later, for example to resume training or to reuse a trained model for inference.

PyTorch provides two ways to do so.

1. Save and restore only the parameters (recommended)

import torch

torch.save(the_model.state_dict(), PATH)    # save only the parameters (the state_dict) to PATH

the_model = TheModelClass(*args, **kwargs)    # re-create the_model as an object of TheModelClass
the_model.load_state_dict(torch.load(PATH))    # load the parameters from PATH into it
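
Below is a minimal, self-contained sketch of method 1 that can be run end to end. `TinyNet`, its layer sizes, and the temporary file name are hypothetical stand-ins for `TheModelClass` and `PATH`, and it assumes PyTorch 1.13 or newer (for the `weights_only` argument). It saves the state_dict, rebuilds the model, loads the tensors back, and checks that both models give the same output.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tinynet_params.pt")  # hypothetical file name

# Save only the learned tensors (an ordered mapping of name -> tensor).
torch.save(model.state_dict(), path)

# Re-create the architecture in code, then copy the saved tensors into it.
restored = TinyNet()
# weights_only=True restricts unpickling to tensors and plain containers.
state = torch.load(path, map_location="cpu", weights_only=True)
restored.load_state_dict(state)
restored.eval()  # inference mode; only matters for layers like dropout/batchnorm

x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
print(same_output)  # True
```

Because the file holds only tensors, it does not record where the class is defined; the trade-off is that you must rebuild the same architecture before calling `load_state_dict`.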

2. Save the whole model (structure and parameters)

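torch.save(the_model, PATH)    # pickle the whole model object to PATH

the_model = torch.load(PATH, weights_only=False)    # TheModelClass must be importable; weights_only=False is required on PyTorch 2.6+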
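
A minimal sketch of method 2 under similar assumptions: a hypothetical `TinyNet`, a temporary path, and PyTorch 1.13 or newer so that `torch.load` accepts `weights_only`. Here the whole module object is pickled, so loading only works if the `TinyNet` class is importable under the same module path as when it was saved.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; the class must stay importable under this name to be unpickled."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tinynet_full.pt")  # hypothetical file name

# Pickles the whole module: its parameters plus a reference to the TinyNet class.
torch.save(model, path)

# PyTorch 2.6+ defaults to weights_only=True, which refuses arbitrary objects,
# so weights_only=False is passed here. Only do this for files you trust.
restored = torch.load(path, weights_only=False)
restored.eval()

print(type(restored).__name__)  # TinyNet
print(torch.equal(restored.fc.weight, model.fc.weight))  # True
```

This is convenient, but it ties the saved file to your code layout: renaming or moving `TinyNet` can break loading of old files, which the parameters-only approach avoids.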

3. Get the parameters of a specific layer

params = model.state_dict()
for k, v in params.items():
    print(k)    # print the parameter names in the network
print(params['conv1.weight'])    # print conv1's weight
print(params['conv1.bias'])    # print conv1's bias
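
A sketch of where keys such as `conv1.weight` come from, assuming a hypothetical `SmallConvNet` whose first layer is an attribute named `conv1`: each key is the layer's attribute name followed by the parameter name. The size of `fc1` assumes a 1-channel 28x28 input, which a 5x5 convolution reduces to 24x24.

```python
import torch
import torch.nn as nn


class SmallConvNet(nn.Module):
    """Hypothetical network with a layer named conv1, matching the keys used above."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.fc1 = nn.Linear(6 * 24 * 24, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.fc1(x.flatten(start_dim=1))


model = SmallConvNet()
params = model.state_dict()

# Keys are "<attribute name>.<parameter name>", in registration order.
for name, tensor in params.items():
    print(name, tuple(tensor.shape))
# conv1.weight (6, 1, 5, 5)
# conv1.bias (6,)
# fc1.weight (10, 3456)
# fc1.bias (10,)

conv1_weight = params["conv1.weight"]
conv1_bias = params["conv1.bias"]
```

The tensors returned by `state_dict()` share storage with the model's parameters, so call `.clone()` on them if you want a snapshot that later training will not change.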

  

Reference: http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

  
