1. 程式人生 > >PyTorch學習(7)—儲存和載入訓練結果

PyTorch學習(7)—儲存和載入訓練結果

本篇部落格主要介紹如何在PyTorch中儲存和載入模型訓練的結果。

對訓練結果進行儲存,有兩種方式,一種是儲存整個網路,另一種是儲存訓練好的引數,相對而言,第二種方式具有更高的效率。

下面是示例程式碼:

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt


# 生成假資料
# torch.unsqueeze() 的作用是將一維變二維,torch只能處理二維的資料
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape(100, 1)
# 0.2 * torch.rand(x.size())增加噪點
y = x.pow(2) + 0.2 + 0.2 * torch.rand(x.size())

# 將Tensor轉換為torch
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)

# 儲存神經網路


def save():
    # 搭建神經網路
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # 優化器:隨機梯度下降
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)

    # 損失函式:均方差
    loss_func = torch.nn.MSELoss()

    # 訓練100步
    for i in range(100):
        prediction = net1(x)
        # 求誤差
        loss = loss_func(prediction, y)
        # 優化
        optimizer.zero_grad()
        # 反向傳播
        loss.backward()
        optimizer.step()
    # 繪圖
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    # 儲存
    torch.save(net1, 'net.pkl')  # save entire net
    torch.save(net1.state_dict(), 'net_params.pkl')   # save parameters


# 提取神經網路:儲存整個網路的形式
def restore_net():
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    # 繪圖
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)


# 提取神經網路:儲存網路引數的方式
# 這種方式速度更快
def restore_params():
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    # 繪圖
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()

save()

restore_net()

restore_params()

執行結果:

相關推薦

PyTorch學習7儲存載入訓練結果

本篇部落格主要介紹如何在PyTorch中儲存和載入模型訓練的結果。 對訓練結果進行儲存,有兩種方式,一種是儲存整個網路,另一種是儲存訓練好的引數,相對而言,第二種方式具有更高的效率。 下面是示例程式碼: import torch from torch.autograd

kafka學習7生產者消費者程式碼

首先,我們開啟kafka的api頁面,裡面都有詳細的樣例 http://kafka.apache.org/0100/javado

QTopencv學習opencv的載入、顯示、修改、儲存影象

載入影象(用cv::imread) imread功能是載入影象檔案成為一個Mat物件,其中第一個引數表示影象檔名稱 第二個引數,表示載入的影象是什麼型別,支援常見的三個引數值 IMREAD_UNCHANGED (<0) 表示載入原圖,不做任何改變

R語言學習7字符串因子

const sprint 水平 tostring 大小 pow 個數 end paste 字符串和因子 1.字符串   創建字符串 > c("Hello","World")[1] "Hello" "World"   paste( ) 函數連接字符串 >

微信小程式學習18 —— 上拉載入下拉重新整理

在微信小程式上實現下拉重新整理、上拉載入的效果 使用系統提供的onPullDownRefresh、onReachBottom這2個事件, 前提需要在app.json或page.json配置檔案中設定,才能使用。 app.json是全應用的頁面都可以使用該事件

python學習7:python爬蟲之爬取動態載入的圖片,以百度圖片為例

前言: 前面我們爬取圖片的網站都是靜態的,在頁面中右鍵檢視原始碼就能看到網頁中圖片的位置。這樣我們用requests庫得到頁面原始碼後,再用bs4庫解析標籤即可儲存圖片到本地。 當我們在看百度圖片時,右鍵–檢查–Elements,點選箭頭,再用箭頭點選圖片時

Django入門學習7——自定義管理器模型類的建立方法

自定義管理器的目的1:更改查詢集# -*- coding:utf-8 -*- from django.db import models class BookInfoManager(models.Manager): def get_queryset(self):

《構建之法》學習7——MSF

發現 解決方案 msf 我們 基本原則 無法 strong 出了 微軟 《構建之法》學習(7)——MSF 1.MSF簡史   微軟解決方案框架,也就是微軟推薦的軟件開發方法 2.MSF基本原則   推動信息共享與溝通   所有信息都保留並公開,討論要包括所有

Java學習7:同步問題之生產者與消費者的問題

con runnable pop pre 標記 this auth style about 生產者生產饅頭,消費者消費饅頭。一個籃子,生產者往籃子中放饅頭,消費者從籃子中取饅頭。 /** * 這是一個籃子類 * * @author xcx * @time 2017

Guice 學習常量屬性的註入 Constant and Property Inject

-a ret roc build ann class google mes ota 1、常量註入方式 package com.guice.ConstantInjectDemo; import com.google.inject.Binder; i

express學習—— cookiesession

aaa 獲取 不知道 cookies htm 服務器 字符串 lis dom express學習(三)—— cookie和session cookie存在瀏覽器中,最大只能保存4K數據,不安全 session存在服務器中,不能獨立(先讀取cookie再讀取sessio

機器學習—Adaboost 梯度提升樹GBDT

獲得 決策樹 info gin 否則 它的 均方差 但是 ont 1、Adaboost算法原理,優缺點:   理論上任何學習器都可以用於Adaboost.但一般來說,使用最廣泛的Adaboost弱學習器是決策樹和神經網絡。對於決策樹,Adaboost分類用了CART分類樹,

Java SpringMVC框架學習httpServeltRequestModel傳值的區別

urn ont ppi mode array style att 區別 () 為什麽大多程序在controller中給jsp傳值時使用model.addAttribute()而不使用httpServeletRequest.setAttribute()? 事實上model數

ES6知識整理7--SetMap數據結構

ora ear 踏實 9.png 叠代 數據 edi KS 返回鍵 (文章會同步到博客園,技術類文章還是該讓搜索引擎察覺比較好)Set構造函數初始化一個值不重復的數組,適合做數組去重。2種數組去重的方法:這裏再說下Array.from(),表示以一個類數組||可叠代對象,創

Python語言程式設計基礎7—— 檔案資料格式化

返回字串 file = input() #返回字串 fo = open(file,"r").read(6) print(fo)   返回列表形式 file = input() fo = open(file,"r") #print(fo) #返回列表形式 pr

pytorch筆記02)模型的儲存載入

儲存和載入整個模型 torch.save(model_object, 'model.pkl') model = torch.load('model.pkl') 僅儲存和載入模型引數(推薦使用,需要提前手動構建模型) torch.save(model_object.state_

HTML的學習7

表格裡的一些照片 現在已經知道關於表格的一些東西、但是表格相比格式化文字而言其實有另外的用處。您可以使用一個表格來建立照片格式化的精美網路。 以後會進行css的學習,進行課程和專案的訓練。 練習題: <!DOCTYPE html> <html> <

ES6學習---letconst用法

1.let用法 (1)存在作用域,即let宣告函式會在花括號中執行 (2)es6規定暫時性死區,暫時性死區通俗的來講就是一個區塊中存在let和const宣告的變數,那麼該區塊會形成封閉作用域,在let和const宣告之前使用該變數都會報錯。 //例子1 { console.lo

Java中io流的學習ByteArrayInputStreamByteArrayOutputStream

ByteArrayInputStream(記憶體輸入流)繼承於InputStream,ByteArrayOutputStream(記憶體輸出流)繼承於OutputStream。記憶體流是關不掉的,一般用來存放一些臨時性的資料,理論值是記憶體大小。 常用的方法是:read(),一系列read方法,

Python學習7——面向物件高階編輯

1、使用__slots__ (1)可以嘗試給例項繫結一個方法: def set_age(self, age): self.age = age from types import MethodType s.set_age = MethodType(set_age, s) # 給例項