torch學習筆記1:實現自定義層
當我們要實現自己的一些idea時,torch自帶的模組和函式已經不能滿足,我們需要自己實現層(或者類),一般的做法是把自定義層加入到已有的torch模組中。
實現
lua實現
如果自定義層的功能可以通過呼叫torch中已有的函式實現,那就只需要用lua實現,torch的文件中也提供了簡單的說明。
現在我們來實現一個NewClass:
- 在torch目錄下(
torch/extra/nn/
)建立檔案NewClass.lua - 參考nn中其他lua檔案的結構寫好模板,在對應的函式中實現想要的功能
--建立新類,從nn.Module繼承
local NewClass, Parent = torch.class('nn.NewClass' , 'nn.Module')
--初始化操作
function NewClass:__init()
Parent.__init(self)
end
--前向傳播
function NewClass:updateOutput(input)
end
--反向傳播
function NewClass:updateGradInput(input, gradOutput)
end
--損失對引數的偏導,也就是殘差,如果該層沒有要學習的引數,則不需要寫這個函式
function NewClass:accGradParameters(input, gradOutput)
end
- 在nn的init.lua中末尾新增一句
require('nn.NewClass')
- 重新安裝nn模組
cd torch/extra/nn/
luarocks make rocks/nn-scm-1.rockspec
- 安裝成功後,在自己的程式碼中使用自定義的類了
require 'nn'
...
nn.NewClass()
...
CPU實現
如果通過torch的函式不能實現出需要的功能,那麼需要自己寫C程式實現核心功能,然後在NewClass.lua中呼叫。
- 在
torch/extra/nn/lib/THNN/generic/
目錄下建立檔案NewClass.c - 參考nn中已有的實現,在函式中實現需要的功能
...
void THNN_(NewClass_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output)
{
}
void THNN_(NewClass_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput)
{
}
...
- 宣告已實現的函式,在
torch/extra/nn/lib/THNN/generic/THNN.h
中新增
...
TH_API void THNN_(NewClass_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output);
TH_API void THNN_(NewClass_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput);
...
- 新增include,在
torch/extra/nn/lib/THNN/init.c
中,新增
#include "generic/NewClass.c"
#include "THGenerateFloatTypes.h"
- 在NewClass.lua中呼叫CPU版本的函式
...
function NewClass:updateOutput(input)
input.THNN.NewClass_updateOutput(
input:cdata(),
self.output:cdata()
)
return self.output
end
function NewClass:updateGradInput(input, gradOutput)
if self.gradInput then
input.THNN.NewClass_updateGradInput(
input:cdata(),
self.gradInput:cdata(),
gradOutput:cdata()
)
return self.gradInput
end
end
...
- 重新編譯安裝nn
cd torch/extra/nn/
luarocks make rocks/nn-scm-1.rockspec
- 安裝成功後,在自己的程式碼中使用自定義的類了
require 'nn'
...
nn.NewClass()
...
Cuda實現
如果想要進一步提升運算效率,需要自己寫一個Cuda版本的程式。
- 在
torch/extra/cunn/lib/THCUNN/
目錄下建立檔案NewClass.cu - 參考cunn中已有的函式,實現函式功能
...
void THNN_CudaNewClass_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output)
{
}
void THNN_CudaNewClass_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput)
{
}
...
- 宣告函式,在
torch/extra/cunn/lib/THCUNN/THCUNN.h
中新增:
TH_API void THNN_CudaNewClass_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output);
TH_API void THNN_CudaNewClass_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput);
- 在NewClass.lua中呼叫GPU版本的函式,和CPU版本一樣,都通過THNN呼叫
- 重新編譯安裝cunn
cd torch/extra/cunn/
luarocks make rocks/cunn-scm-1.rockspec
- 安裝成功後,在自己的程式碼中使用自定義的類了
require 'cunn'
...
nn.NewClass()
...
測試
在torch/extra/nn/test.lua
和torch/extra/cunn/test.lua
中新增測試程式碼,可以用來測試NewClass的輸出是否正確,具體可參考已有的測試程式碼。
新增好後,執行th -lnn -e "nn.test{'NewClass'}"
即可測試。
相關推薦
torch學習筆記1:實現自定義層
當我們要實現自己的一些idea時,torch自帶的模組和函式已經不能滿足,我們需要自己實現層(或者類),一般的做法是把自定義層加入到已有的torch模組中。 實現 lua實現 如果自定義層的功能可以通過呼叫torch中已有的函式實現,那就只需要用l
tensorflow學習筆記(三):實現自編碼器
sea start ear var logs cos soft 編碼 red 黃文堅的tensorflow實戰一書中的第四章,講述了tensorflow實現多層感知機。Hiton早年提出過自編碼器的非監督學習算法,書中的代碼給出了一個隱藏層的神經網絡,本人擴展到了多層,改進
react native學習筆記24——Modal實現自定義彈出對話方塊
前言 上一篇文章介紹React Native系統提供的兩個彈出框的api——Alert與AlertIOS,Alert可以在雙平臺通用,但是隻能展示資訊量有限功能單一的文字對話方塊。AlertIOS比Alert稍微豐富一點,可以展示供使用者輸入的對話方塊,但只能
torch學習筆記3.2:實現自定義模組(cpu)
在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新
torch學習筆記3.3:實現自定義模組(gpu)
在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新
Python學習筆記1:簡單實現ssh客戶端和服務端
bsp dev bre 客戶端 break 基於 bin listen 客戶 實現基於python 3.6。 server端: 1 __author__ = "PyDev2018" 2 3 import socket,os 4 server = socket.s
美國高通 Snapdragon Neural Processing Engine SDK (SNPE) 系列 (1):使用者自定義層JNI實現
轉自:https://blog.csdn.net/guvcolie/article/details/77937786 Snapdragon Neural Processing Engine SDK是美國高通公司出品的神經網路處理引擎(SNP
Vue:學習筆記(七)-自定義指令
提醒 原帖完整收藏於IT老兵驛站,並會不斷更新。 前言 前面總結到了元件,對混入也進行了研究,不過感覺沒有啥需要總結的,就先總結指令吧,參考這裡,記錄筆記。 正文 簡介 全域性註冊 // 註冊一個全域性自定義指令 `v-focus` Vue.di
機器學習筆記1:機器學習定義與分類
機器學習定義與分類 Andrew Ng機器學習課程學習筆記1 定義 Arthur Samuel (1959) Machine Learning: Field of study that gives computers the ability to l
斯坦福機器學習筆記1:GDA高斯判別分析演算法的原理及matlab程式實現
ps:我本身沒有系統的學過matlab程式設計,所以有的方法,比如求均值用mean()函式之類的方法都是用很笨的方法實現的,所以有很多需要改進的地方,另外是自學實現的程式,可能有的地方我理解錯誤,如果有錯誤請提出來,大家一起學習,本人qq553566286 首先,本文用到的
ESP8266學習筆記2:實現ESP8266的局域網內通信
pro reg sad net nts 理解 模式 curl ont 上一篇熟悉了編譯下載操作。如今就以實例入手。project使用的是IOT_DEMO,據DEMO文檔能夠知道ESP8266初始工作模式為softAP+station共存的模式。於是這邊我們就先以soft
AngularJs學習筆記(4)——自定義指令
ref 告訴 ack 生命周期 .com bsp ctrl 參數變量 ng- 對指令的第一印象:它是一個自定義標簽! 先來看一個簡單的指令: <!doctype html> <html ng-app="myApp"> <head>
Effictive Java學習筆記1:創建和銷毀對象
安全 需要 () 函數 調用 bsp nbsp bean 成了 建議1:考慮用靜態工廠方法代替構造器 理由:1)靜態方法有名字啊,更容易懂和理解。構造方法重載容易讓人混淆,並不是好主意 2)靜態工廠方法可以不必每次調用時都創建一個新對象,而公共構造函數每次調用都會
golang學習筆記(1):安裝&helloworld
golang安裝:golang編譯器安裝過程比較簡單,也比較快,不同平臺下(win/linux/macos)都比較相似;https://dl.gocn.io/golang/1.9.2/go1.9.2.src.tar.gz 下載對應的系統版本的編譯器go的版本號由"." 分為3部分如當前的
寒假學習筆記1:結構化程序設計
控制流程 ram 循環 只有一個 嚴格 學習筆記 程序編寫 ont 部分 結構化程序設計(structured programming)是進行以模塊功能和處理過程設計為主的詳細設計的基本原則。 - 內容 主張使用順序、選擇、循環三種基本結構來嵌套連結成具有復雜層次的“結構
hibernate框架學習筆記1:搭建與測試
for this ble action 1.7 turn yiq targe cts hibernate框架屬於dao層,類似dbutils的作用,是一款ORM(對象關系映射)操作 使用hibernate框架好處是:操作數據庫不需要寫SQL語句,使用面向對象的方式完成
struts2框架學習筆記1:搭建測試
method lang app org char 示例 重要 type img Servlet是線程不安全的,Struts1是基於Servlet的框架 而Struts2是基於Filter的框架,解決了線程安全問題 因此Struts1和Struts2基本沒有關系,只是創造者取
Python學習筆記1:用戶登錄
\n win col lines %s courier class for ID 1 import getpass,sys 2 u=0 3 while u< 3: 4 user_name = input(‘Please input you
機器學習之路: tensorflow 自定義 損失函數
cond pre port var IV 學習 col float ria git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/ 1 import tensor
Python3學習筆記1:變量和簡單數據類型
tle 小數點 per port 小數 指導 day this python 2018-09-16 17:22:11 變量聲明: 變量名 = ?? 如: 1 message = "HelloWorld" 2 message = 1 3 message =