1. 程式人生 > >torch學習筆記3.3:實現自定義模組(gpu)

torch學習筆記3.3:實現自定義模組(gpu)

在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新增,會比較混亂,並且不利於本地git倉庫統一管理,這個時候,我們可以自己實現一個像nn一樣的模組,在程式碼中使用時 require即可。

我們來實現一個名為nxn的自定義模組,以及它的cuda版本cunxn模組,其中包含一個自定義的Hello類(lua實現),ReLU類(分別用CPU和GPU實現)。

由於篇幅原因,這裡把torch自定義模組的lua實現,cpu實現,gpu實現分別寫一篇文章,本文介紹cpu實現的ReLU類。

3 檔案說明

這裡介紹的都是cunxn資料夾裡面的。

CMakeLists.txt

可以參考torch自帶模組來寫,主要是cuda檔案的編譯和連結,需要注意的部分內容如下:

......
FIND_PACKAGE(CUDA 4.0 REQUIRED)

SET(src-cuda init.cu)

CUDA_ADD_LIBRARY(cunxn MODULE ${src-cuda})
TARGET_LINK_LIBRARIES(cunxn luaT THC TH)
IF(APPLE)
  SET_TARGET_PROPERTIES(cunxn PROPERTIES
    LINK_FLAGS "-undefined dynamic_lookup"
) ENDIF() ### Torch packages supposes libraries prefix is "lib" SET_TARGET_PROPERTIES(cunxn PROPERTIES PREFIX "lib" IMPORT_PREFIX "lib") INSTALL(TARGETS cunxn RUNTIME DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}" LIBRARY DESTINATION "${Torch_INSTALL_LUA_CPATH_SUBDIR}") SET(luasrc init.lua) INSTALL
( FILES ${luasrc} DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/cunxn") ADD_TORCH_PACKAGE(cunxn "" "${luasrc}")

cunxn-scm-1.rockspec

其中的build部分和其他rockspec檔案一樣

package = "cunxn"
version = "scm-1"

source = {
   url = "git://github.com/soumith/examplepackage.torch",
   tag = "master"
}

dependencies = {
   "torch >= 7.0",
   "cunn",
   "nn"
}

......

init.cu

同init.c的功能一樣,編譯時查詢要編譯的檔案,以及生成libcunxn:

#include "luaT.h"
#include "THC.h"
#include "THLogAdd.h" /* DEBUG: WTF */

#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/transform_reduce.h>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>

#include "ReLU.cu"



LUA_EXTERNC DLL_EXPORT int luaopen_libcunxn(lua_State *L);

int luaopen_libcunxn(lua_State *L)
{
  lua_newtable(L);

  cunxn_ReLU_init(L);

  return 1;
}

init.lua

require "cutorch"
require "nxn"
require "libcunxn"

ReLU.cu

cuda實現的ReLU

struct reluupdateOutput_functor
{
  __host__ __device__ float operator()(const float& input) const
  {
    return input > 0 ? input : 0;
  }
};

THCState* getCutorchState(lua_State* L)
{
    lua_getglobal(L, "cutorch");
    lua_getfield(L, -1, "getState");
    lua_call(L, 0, 1);
    THCState *state = (THCState*) lua_touserdata(L, -1);
    lua_pop(L, 2);
    return state;
} 

static int cunxn_ReLU_updateOutput(lua_State *L)
{
  printf("GPU version of ReLU updateOutput function\n");
  THCState *state = getCutorchState(L);
  THCudaTensor *input = (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  long size = THCudaTensor_nElement(state, input);

  input = THCudaTensor_newContiguous(state, input);

  THCudaTensor_resizeAs(state, output, input);

  thrust::device_ptr<float> output_data(THCudaTensor_data(state, output));
  thrust::device_ptr<float> input_data(THCudaTensor_data(state, input));
  thrust::transform(input_data, input_data+size, output_data, reluupdateOutput_functor());

  THCudaTensor_free(state, input);
  return 1;
}

struct reluupdateGradInput_functor
{
  __host__ __device__ float operator()(const float& output, const float& gradOutput) const
  {
    return gradOutput * (output > 0 ? 1 : 0);
  }
};

static int cunxn_ReLU_updateGradInput(lua_State *L)
{
  printf("GPU version of ReLU updateGradInput function\n");
  THCState *state = getCutorchState(L);
  THCudaTensor *output = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THCudaTensor *gradOutput = (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
  THCudaTensor *gradInput = (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  long size = THCudaTensor_nElement(state, output);

  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  THCudaTensor_resizeAs(state, gradInput, output);

  thrust::device_ptr<float> output_data(THCudaTensor_data(state, output));
  thrust::device_ptr<float> gradOutput_data(THCudaTensor_data(state, gradOutput));
  thrust::device_ptr<float> gradInput_data(THCudaTensor_data(state, gradInput));
  thrust::transform(output_data, output_data+size, gradOutput_data, gradInput_data, reluupdateGradInput_functor());

  THCudaTensor_free(state, gradOutput);
  return 1;
}

static const struct luaL_Reg cunxn_ReLU__ [] = {
  {"ReLU_updateOutput", cunxn_ReLU_updateOutput},
  {"ReLU_updateGradInput", cunxn_ReLU_updateGradInput},
  {NULL, NULL}
};

static void cunxn_ReLU_init(lua_State *L)
{
  luaT_pushmetatable(L, "torch.CudaTensor");
  luaT_registeratname(L, cunxn_ReLU__, "nxn");
  lua_pop(L,1);
}