1. 程式人生 > >torch學習筆記1:實現自定義層

torch學習筆記1:實現自定義層

當我們要實現自己的一些idea時,torch自帶的模組和函式已經不能滿足,我們需要自己實現層(或者類),一般的做法是把自定義層加入到已有的torch模組中。

實現

lua實現

如果自定義層的功能可以通過呼叫torch中已有的函式實現,那就只需要用lua實現,torch的文件中也提供了簡單的說明
現在我們來實現一個NewClass:

  • 在torch目錄下(torch/extra/nn/)建立檔案NewClass.lua
  • 參考nn中其他lua檔案的結構寫好模板,在對應的函式中實現想要的功能
--建立新類,從nn.Module繼承
local NewClass, Parent = torch.class('nn.NewClass'
, 'nn.Module') --初始化操作 function NewClass:__init() Parent.__init(self) end --前向傳播 function NewClass:updateOutput(input) end --反向傳播 function NewClass:updateGradInput(input, gradOutput) end --損失對引數的偏導,也就是殘差,如果該層沒有要學習的引數,則不需要寫這個函式 function NewClass:accGradParameters(input, gradOutput) end
  • 在nn的init.lua中末尾新增一句
require('nn.NewClass')
  • 重新安裝nn模組
cd torch/extra/nn/
luarocks make rocks/nn-scm-1.rockspec
  • 安裝成功後,在自己的程式碼中使用自定義的類了
require 'nn'
...
nn.NewClass()
...

CPU實現

如果通過torch的函式不能實現出需要的功能,那麼需要自己寫C程式實現核心功能,然後在NewClass.lua中呼叫。

  • torch/extra/nn/lib/THNN/generic/目錄下建立檔案NewClass.c
  • 參考nn中已有的實現,在函式中實現需要的功能
...
void THNN_(NewClass_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output)
{

}

void THNN_(NewClass_updateGradInput)(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput)
{

}
...
  • 宣告已實現的函式,在torch/extra/nn/lib/THNN/generic/THNN.h中新增
...
TH_API void THNN_(NewClass_updateOutput)(
          THNNState *state,            
          THTensor *input,             
          THTensor *output);           
TH_API void THNN_(NewClass_updateGradInput)(
          THNNState *state,            
          THTensor *input,            
          THTensor *gradOutput,      
          THTensor *gradInput);    
...    
  • 新增include,在torch/extra/nn/lib/THNN/init.c中,新增
#include "generic/NewClass.c"
#include "THGenerateFloatTypes.h"
  • 在NewClass.lua中呼叫CPU版本的函式
...
function NewClass:updateOutput(input)
   input.THNN.NewClass_updateOutput(
      input:cdata(),
      self.output:cdata()
   )
   return self.output
end

function NewClass:updateGradInput(input, gradOutput)

   if self.gradInput then
      input.THNN.NewClass_updateGradInput(
         input:cdata(),
         self.gradInput:cdata(),
         gradOutput:cdata()
      )
      return self.gradInput
   end
end
...
  • 重新編譯安裝nn
cd torch/extra/nn/
luarocks make rocks/nn-scm-1.rockspec
  • 安裝成功後,在自己的程式碼中使用自定義的類了
require 'nn'
...
nn.NewClass()
...

Cuda實現

如果想要進一步提升運算效率,需要自己寫一個Cuda版本的程式。

  • torch/extra/cunn/lib/THCUNN/目錄下建立檔案NewClass.cu
  • 參考cunn中已有的函式,實現函式功能
...
void THNN_CudaNewClass_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output)
{

}

void THNN_CudaNewClass_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput)
{

}
...
  • 宣告函式,在torch/extra/cunn/lib/THCUNN/THCUNN.h中新增:
TH_API void THNN_CudaNewClass_updateOutput(
          THCState *state,
          THCudaTensor *input,
          THCudaTensor *output);
TH_API void THNN_CudaNewClass_updateGradInput(
          THCState *state,
          THCudaTensor *input,
          THCudaTensor *gradOutput,
          THCudaTensor *gradInput);
  • 在NewClass.lua中呼叫GPU版本的函式,和CPU版本一樣,都通過THNN呼叫
  • 重新編譯安裝cunn
cd torch/extra/cunn/
luarocks make rocks/cunn-scm-1.rockspec
  • 安裝成功後,在自己的程式碼中使用自定義的類了
require 'cunn'
...
nn.NewClass()
...

測試

torch/extra/nn/test.luatorch/extra/cunn/test.lua中新增測試程式碼,可以用來測試NewClass的輸出是否正確,具體可參考已有的測試程式碼。
新增好後,執行th -lnn -e "nn.test{'NewClass'}"即可測試。

相關推薦

torch學習筆記1實現定義

當我們要實現自己的一些idea時,torch自帶的模組和函式已經不能滿足,我們需要自己實現層(或者類),一般的做法是把自定義層加入到已有的torch模組中。 實現 lua實現 如果自定義層的功能可以通過呼叫torch中已有的函式實現,那就只需要用l

tensorflow學習筆記(三)實現編碼器

sea start ear var logs cos soft 編碼 red 黃文堅的tensorflow實戰一書中的第四章,講述了tensorflow實現多層感知機。Hiton早年提出過自編碼器的非監督學習算法,書中的代碼給出了一個隱藏層的神經網絡,本人擴展到了多層,改進

react native學習筆記24——Modal實現定義彈出對話方塊

前言 上一篇文章介紹React Native系統提供的兩個彈出框的api——Alert與AlertIOS,Alert可以在雙平臺通用,但是隻能展示資訊量有限功能單一的文字對話方塊。AlertIOS比Alert稍微豐富一點,可以展示供使用者輸入的對話方塊,但只能

torch學習筆記3.2實現定義模組(cpu)

在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新

torch學習筆記3.3實現定義模組(gpu)

在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新

Python學習筆記1簡單實現ssh客戶端和服務端

bsp dev bre 客戶端 break 基於 bin listen 客戶 實現基於python 3.6。 server端: 1 __author__ = "PyDev2018" 2 3 import socket,os 4 server = socket.s

美國高通 Snapdragon Neural Processing Engine SDK (SNPE) 系列 (1使用者定義JNI實現

轉自:https://blog.csdn.net/guvcolie/article/details/77937786         Snapdragon Neural Processing Engine SDK是美國高通公司出品的神經網路處理引擎(SNP

Vue學習筆記(七)-定義指令

提醒 原帖完整收藏於IT老兵驛站,並會不斷更新。 前言 前面總結到了元件,對混入也進行了研究,不過感覺沒有啥需要總結的,就先總結指令吧,參考這裡,記錄筆記。 正文 簡介 全域性註冊 // 註冊一個全域性自定義指令 `v-focus` Vue.di

機器學習筆記1機器學習定義與分類

機器學習定義與分類 Andrew Ng機器學習課程學習筆記1 定義 Arthur Samuel (1959) Machine Learning: Field of study that gives computers the ability to l

斯坦福機器學習筆記1GDA高斯判別分析演算法的原理及matlab程式實現

ps:我本身沒有系統的學過matlab程式設計,所以有的方法,比如求均值用mean()函式之類的方法都是用很笨的方法實現的,所以有很多需要改進的地方,另外是自學實現的程式,可能有的地方我理解錯誤,如果有錯誤請提出來,大家一起學習,本人qq553566286 首先,本文用到的

ESP8266學習筆記2實現ESP8266的局域網內通信

pro reg sad net nts 理解 模式 curl ont 上一篇熟悉了編譯下載操作。如今就以實例入手。project使用的是IOT_DEMO,據DEMO文檔能夠知道ESP8266初始工作模式為softAP+station共存的模式。於是這邊我們就先以soft

AngularJs學習筆記(4)——定義指令

ref 告訴 ack 生命周期 .com bsp ctrl 參數變量 ng- 對指令的第一印象:它是一個自定義標簽! 先來看一個簡單的指令: <!doctype html> <html ng-app="myApp"> <head>

Effictive Java學習筆記1創建和銷毀對象

安全 需要 () 函數 調用 bsp nbsp bean 成了 建議1:考慮用靜態工廠方法代替構造器 理由:1)靜態方法有名字啊,更容易懂和理解。構造方法重載容易讓人混淆,並不是好主意    2)靜態工廠方法可以不必每次調用時都創建一個新對象,而公共構造函數每次調用都會

golang學習筆記(1)安裝&helloworld

golang安裝:golang編譯器安裝過程比較簡單,也比較快,不同平臺下(win/linux/macos)都比較相似;https://dl.gocn.io/golang/1.9.2/go1.9.2.src.tar.gz 下載對應的系統版本的編譯器go的版本號由"." 分為3部分如當前的

寒假學習筆記1結構化程序設計

控制流程 ram 循環 只有一個 嚴格 學習筆記 程序編寫 ont 部分 結構化程序設計(structured programming)是進行以模塊功能和處理過程設計為主的詳細設計的基本原則。 - 內容 主張使用順序、選擇、循環三種基本結構來嵌套連結成具有復雜層次的“結構

hibernate框架學習筆記1搭建與測試

for this ble action 1.7 turn yiq targe cts hibernate框架屬於dao層,類似dbutils的作用,是一款ORM(對象關系映射)操作 使用hibernate框架好處是:操作數據庫不需要寫SQL語句,使用面向對象的方式完成

struts2框架學習筆記1搭建測試

method lang app org char 示例 重要 type img Servlet是線程不安全的,Struts1是基於Servlet的框架 而Struts2是基於Filter的框架,解決了線程安全問題 因此Struts1和Struts2基本沒有關系,只是創造者取

Python學習筆記1用戶登錄

\n win col lines %s courier class for ID 1 import getpass,sys 2 u=0 3 while u< 3: 4 user_name = input(‘Please input you

機器學習之路 tensorflow 定義 損失函數

cond pre port var IV 學習 col float ria git: https://github.com/linyi0604/MachineLearning/tree/master/07_tensorflow/ 1 import tensor

Python3學習筆記1變量和簡單數據類型

tle 小數點 per port 小數 指導 day this python 2018-09-16 17:22:11 變量聲明:   變量名 = ?? 如: 1 message = "HelloWorld" 2 message = 1 3 message =