PyTorch學習(7)—儲存和載入訓練結果
本篇部落格主要介紹如何在PyTorch中儲存和載入模型訓練的結果。
對訓練結果進行儲存,有兩種方式,一種是儲存整個網路,另一種是儲存訓練好的引數,相對而言,第二種方式具有更高的效率。
下面是示例程式碼:
import torch from torch.autograd import Variable import matplotlib.pyplot as plt # 生成假資料 # torch.unsqueeze() 的作用是將一維變二維,torch只能處理二維的資料 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape(100, 1) # 0.2 * torch.rand(x.size())增加噪點 y = x.pow(2) + 0.2 + 0.2 * torch.rand(x.size()) # 將Tensor轉換為torch x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) # 儲存神經網路 def save(): # 搭建神經網路 net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) # 優化器:隨機梯度下降 optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) # 損失函式:均方差 loss_func = torch.nn.MSELoss() # 訓練100步 for i in range(100): prediction = net1(x) # 求誤差 loss = loss_func(prediction, y) # 優化 optimizer.zero_grad() # 反向傳播 loss.backward() optimizer.step() # 繪圖 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('Net1') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 儲存 torch.save(net1, 'net.pkl') # save entire net torch.save(net1.state_dict(), 'net_params.pkl') # save parameters # 提取神經網路:儲存整個網路的形式 def restore_net(): net2 = torch.load('net.pkl') prediction = net2(x) # 繪圖 plt.subplot(132) plt.title('Net2') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 提取神經網路:儲存網路引數的方式 # 這種方式速度更快 def restore_params(): net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) net3.load_state_dict(torch.load('net_params.pkl')) prediction = net3(x) # 繪圖 plt.subplot(133) plt.title('Net3') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.show() save() restore_net() restore_params()
執行結果:
相關推薦
PyTorch學習(7)—儲存和載入訓練結果
本篇部落格主要介紹如何在PyTorch中儲存和載入模型訓練的結果。 對訓練結果進行儲存,有兩種方式,一種是儲存整個網路,另一種是儲存訓練好的引數,相對而言,第二種方式具有更高的效率。 下面是示例程式碼: import torch from torch.autograd
kafka學習(7)生產者和消費者程式碼
首先,我們開啟kafka的api頁面,裡面都有詳細的樣例 http://kafka.apache.org/0100/javado
QT和opencv學習(二)opencv的載入、顯示、修改、儲存影象
載入影象(用cv::imread) imread功能是載入影象檔案成為一個Mat物件,其中第一個引數表示影象檔名稱 第二個引數,表示載入的影象是什麼型別,支援常見的三個引數值 IMREAD_UNCHANGED (<0) 表示載入原圖,不做任何改變
R語言學習(7)字符串和因子
const sprint 水平 tostring 大小 pow 個數 end paste 字符串和因子 1.字符串 創建字符串 > c("Hello","World")[1] "Hello" "World" paste( ) 函數連接字符串 >
微信小程式學習(18) —— 上拉載入和下拉重新整理
在微信小程式上實現下拉重新整理、上拉載入的效果 使用系統提供的onPullDownRefresh、onReachBottom這2個事件, 前提需要在app.json或page.json配置檔案中設定,才能使用。 app.json是全應用的頁面都可以使用該事件
python學習(7):python爬蟲之爬取動態載入的圖片,以百度圖片為例
前言: 前面我們爬取圖片的網站都是靜態的,在頁面中右鍵檢視原始碼就能看到網頁中圖片的位置。這樣我們用requests庫得到頁面原始碼後,再用bs4庫解析標籤即可儲存圖片到本地。 當我們在看百度圖片時,右鍵–檢查–Elements,點選箭頭,再用箭頭點選圖片時
Django入門學習(7)——自定義管理器和模型類的建立方法
自定義管理器的目的1:更改查詢集# -*- coding:utf-8 -*- from django.db import models class BookInfoManager(models.Manager): def get_queryset(self):
《構建之法》學習(7)——MSF
發現 解決方案 msf 我們 基本原則 無法 strong 出了 微軟 《構建之法》學習(7)——MSF 1.MSF簡史 微軟解決方案框架,也就是微軟推薦的軟件開發方法 2.MSF基本原則 推動信息共享與溝通 所有信息都保留並公開,討論要包括所有
Java學習(7):同步問題之生產者與消費者的問題
con runnable pop pre 標記 this auth style about 生產者生產饅頭,消費者消費饅頭。一個籃子,生產者往籃子中放饅頭,消費者從籃子中取饅頭。 /** * 這是一個籃子類 * * @author xcx * @time 2017
Guice 學習(七)常量和屬性的註入( Constant and Property Inject)
-a ret roc build ann class google mes ota 1、常量註入方式 package com.guice.ConstantInjectDemo; import com.google.inject.Binder; i
express學習(三)—— cookie和session
aaa 獲取 不知道 cookies htm 服務器 字符串 lis dom express學習(三)—— cookie和session cookie存在瀏覽器中,最大只能保存4K數據,不安全 session存在服務器中,不能獨立(先讀取cookie再讀取sessio
機器學習(七)—Adaboost 和 梯度提升樹GBDT
獲得 決策樹 info gin 否則 它的 均方差 但是 ont 1、Adaboost算法原理,優缺點: 理論上任何學習器都可以用於Adaboost.但一般來說,使用最廣泛的Adaboost弱學習器是決策樹和神經網絡。對於決策樹,Adaboost分類用了CART分類樹,
Java SpringMVC框架學習(二)httpServeltRequest和Model傳值的區別
urn ont ppi mode array style att 區別 () 為什麽大多程序在controller中給jsp傳值時使用model.addAttribute()而不使用httpServeletRequest.setAttribute()? 事實上model數
ES6知識整理(7)--Set和Map數據結構
ora ear 踏實 9.png 叠代 數據 edi KS 返回鍵 (文章會同步到博客園,技術類文章還是該讓搜索引擎察覺比較好)Set構造函數初始化一個值不重復的數組,適合做數組去重。2種數組去重的方法:這裏再說下Array.from(),表示以一個類數組||可叠代對象,創
Python語言程式設計基礎(7)—— 檔案和資料格式化
返回字串 file = input() #返回字串 fo = open(file,"r").read(6) print(fo) 返回列表形式 file = input() fo = open(file,"r") #print(fo) #返回列表形式 pr
pytorch筆記02)模型的儲存和載入
儲存和載入整個模型 torch.save(model_object, 'model.pkl') model = torch.load('model.pkl') 僅儲存和載入模型引數(推薦使用,需要提前手動構建模型) torch.save(model_object.state_
HTML的學習(7)
表格裡的一些照片 現在已經知道關於表格的一些東西、但是表格相比格式化文字而言其實有另外的用處。您可以使用一個表格來建立照片格式化的精美網路。 以後會進行css的學習,進行課程和專案的訓練。 練習題: <!DOCTYPE html> <html> <
ES6學習(一)---let和const用法
1.let用法 (1)存在作用域,即let宣告函式會在花括號中執行 (2)es6規定暫時性死區,暫時性死區通俗的來講就是一個區塊中存在let和const宣告的變數,那麼該區塊會形成封閉作用域,在let和const宣告之前使用該變數都會報錯。 //例子1 { console.lo
Java中io流的學習(九)ByteArrayInputStream和ByteArrayOutputStream
ByteArrayInputStream(記憶體輸入流)繼承於InputStream,ByteArrayOutputStream(記憶體輸出流)繼承於OutputStream。記憶體流是關不掉的,一般用來存放一些臨時性的資料,理論值是記憶體大小。 常用的方法是:read(),一系列read方法,
Python學習(7)——面向物件高階編輯
1、使用__slots__ (1)可以嘗試給例項繫結一個方法: def set_age(self, age): self.age = age from types import MethodType s.set_age = MethodType(set_age, s) # 給例項