torch学習ノート1:カスタムレイヤの実装
私たちが自分のideaを実装する場合、torchが持っているモジュールと関数はもう満足できません.私たちは自分でレイヤ(またはクラス)を実装する必要があります.一般的には、カスタムレイヤを既存のtorchモジュールに追加することです.
インプリメンテーション
lua実装
カスタムレイヤの機能がtorchに既存の関数を呼び出すことで実現できる場合は、luaで実現するだけで、torchのドキュメントにも簡単な説明があります.新しいクラスを実現します torchディレクトリの下( を作成 nnの他のluaファイルの構造を参照してテンプレートを書き、対応する関数で所望の機能 を実現する. nnのinit.luaの末尾に を追加 nnモジュール を再インストールのインストールに成功した後、自分のコードにカスタムクラス を使用しました.
CPU実装
torchの関数で必要な機能が実現できない場合は、Cプログラムを自分で書いてコア機能を実現し、NewClassで実現する必要がある.luaで呼び出されます. nnに既存の実装を参照し、関数に必要な機能 を実装する.は実装された関数を宣言し、 を追加する. include、 を追加ニュークラスです.luaでCPUバージョンを呼び出す関数 再コンパイルインストールnn のインストールに成功した後、自分のコードにカスタムクラス を使用しました.
Cuda実装
演算効率をさらに向上させるには、自分でCudaバージョンのプログラムを書く必要があります. cunnに既存の関数を参照し、関数機能 を実現する宣言関数、 ニュークラスです.luaでGPUバージョンを呼び出す関数は,CPUバージョンと同様にTHNNで を呼び出す.再コンパイルインストールcunn のインストールに成功した後、自分のコードにカスタムクラス を使用しました.
テスト
インプリメンテーション
lua実装
カスタムレイヤの機能がtorchに既存の関数を呼び出すことで実現できる場合は、luaで実現するだけで、torchのドキュメントにも簡単な説明があります.新しいクラスを実現します
torch/extra/nn/
)にファイルNewClass.lua -- , nn.Module
local NewClass, Parent = torch.class('nn.NewClass', 'nn.Module')
--
function NewClass:__init()
Parent.__init(self)
end
--
function NewClass:updateOutput(input)
end
--
function NewClass:updateGradInput(input, gradOutput)
end
-- , , ,
function NewClass:accGradParameters(input, gradOutput)
end
require('nn.NewClass')
cd torch/extra/nn/
luarocks make rocks/nn-scm-1.rockspec
require 'nn'
...
nn.NewClass()
...
CPU実装
torchの関数で必要な機能が実現できない場合は、Cプログラムを自分で書いてコア機能を実現し、NewClassで実現する必要がある.luaで呼び出されます.
torch/extra/nn/lib/THNN/generic/
ディレクトリの下にファイルNewClassを作成する.c ...
void THNN_(NewClass_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output)
{
}
void THNN_(NewClass_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput)
{
}
...
torch/extra/nn/lib/THNN/generic/THNN.h
に...
TH_API void THNN_(NewClass_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output);
TH_API void THNN_(NewClass_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput);
...
torch/extra/nn/lib/THNN/init.c
に#include "generic/NewClass.c"
#include "THGenerateFloatTypes.h"
...
function NewClass:updateOutput(input)
input.THNN.NewClass_updateOutput(
input:cdata(),
self.output:cdata()
)
return self.output
end
function NewClass:updateGradInput(input, gradOutput)
if self.gradInput then
input.THNN.NewClass_updateGradInput(
input:cdata(),
self.gradInput:cdata(),
gradOutput:cdata()
)
return self.gradInput
end
end
...
cd torch/extra/nn/
luarocks make rocks/nn-scm-1.rockspec
require 'nn'
...
nn.NewClass()
...
Cuda実装
演算効率をさらに向上させるには、自分でCudaバージョンのプログラムを書く必要があります.
torch/extra/cunn/lib/THCUNN/
ディレクトリの下にファイルNewClassを作成する.cu ...
void THNN_CudaNewClass_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output)
{
}
void THNN_CudaNewClass_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput)
{
}
...
torch/extra/cunn/lib/THCUNN/THCUNN.h
に追加:TH_API void THNN_CudaNewClass_updateOutput(
THCState *state,
THCudaTensor *input,
THCudaTensor *output);
TH_API void THNN_CudaNewClass_updateGradInput(
THCState *state,
THCudaTensor *input,
THCudaTensor *gradOutput,
THCudaTensor *gradInput);
cd torch/extra/cunn/
luarocks make rocks/cunn-scm-1.rockspec
require 'cunn'
...
nn.NewClass()
...
テスト
torch/extra/nn/test.lua
とtorch/extra/cunn/test.lua
にテストコードを追加し、NewClassの出力が正しいかどうかをテストするために使用できます.具体的には、既存のテストコードを参照してください.追加後、th -lnn -e "nn.test{'NewClass'}"
を実行するとテストできます.