torch学習ノート1:カスタムレイヤの実装


私たちが自分のideaを実装する場合、torchが持っているモジュールと関数はもう満足できません.私たちは自分でレイヤ(またはクラス)を実装する必要があります.一般的には、カスタムレイヤを既存のtorchモジュールに追加することです.
インプリメンテーション
lua実装
カスタムレイヤの機能がtorchに既存の関数を呼び出すことで実現できる場合は、luaで実現するだけで、torchのドキュメントにも簡単な説明があります.新しいクラスを実現します
  • torchディレクトリの下(torch/extra/nn/)にファイルNewClass.lua
  • を作成
  • nnの他のluaファイルの構造を参照してテンプレートを書き、対応する関数で所望の機能
  • を実現する.
    --    , nn.Module  
    local NewClass, Parent = torch.class('nn.NewClass', 'nn.Module')
    --     
    function NewClass:__init()
       Parent.__init(self)
    end
    --    
    function NewClass:updateOutput(input)
    end
    --    
    function NewClass:updateGradInput(input, gradOutput)
    end
    --        ,     ,            ,         
    function NewClass:accGradParameters(input, gradOutput)
    end
    
  • nnのinit.luaの末尾に
  • を追加
    require('nn.NewClass')
  • nnモジュール
  • を再インストール
    cd torch/extra/nn/
    luarocks make rocks/nn-scm-1.rockspec
  • のインストールに成功した後、自分のコードにカスタムクラス
  • を使用しました.
    require 'nn'
    ...
    nn.NewClass()
    ...

    CPU実装
    torchの関数で必要な機能が実現できない場合は、Cプログラムを自分で書いてコア機能を実現し、NewClassで実現する必要がある.luaで呼び出されます.
  • torch/extra/nn/lib/THNN/generic/ディレクトリの下にファイルNewClassを作成する.c
  • nnに既存の実装を参照し、関数に必要な機能
  • を実装する.
    ...
    void THNN_(NewClass_updateOutput)(
              THNNState *state,
              THTensor *input,
              THTensor *output)
    {
    
    }
    
    void THNN_(NewClass_updateGradInput)(
              THNNState *state,
              THTensor *input,
              THTensor *gradOutput,
              THTensor *gradInput)
    {
    
    }
    ...
  • は実装された関数を宣言し、torch/extra/nn/lib/THNN/generic/THNN.h
  • を追加する.
    ...
    TH_API void THNN_(NewClass_updateOutput)(
              THNNState *state,            
              THTensor *input,             
              THTensor *output);           
    TH_API void THNN_(NewClass_updateGradInput)(
              THNNState *state,            
              THTensor *input,            
              THTensor *gradOutput,      
              THTensor *gradInput);    
    ...    
  • include、torch/extra/nn/lib/THNN/init.c
  • を追加
    #include "generic/NewClass.c"
    #include "THGenerateFloatTypes.h"
  • ニュークラスです.luaでCPUバージョンを呼び出す関数
  • ...
    function NewClass:updateOutput(input)
       input.THNN.NewClass_updateOutput(
          input:cdata(),
          self.output:cdata()
       )
       return self.output
    end
    
    function NewClass:updateGradInput(input, gradOutput)
    
       if self.gradInput then
          input.THNN.NewClass_updateGradInput(
             input:cdata(),
             self.gradInput:cdata(),
             gradOutput:cdata()
          )
          return self.gradInput
       end
    end
    ...
  • 再コンパイルインストールnn
  • cd torch/extra/nn/
    luarocks make rocks/nn-scm-1.rockspec
  • のインストールに成功した後、自分のコードにカスタムクラス
  • を使用しました.
    require 'nn'
    ...
    nn.NewClass()
    ...

    Cuda実装
    演算効率をさらに向上させるには、自分でCudaバージョンのプログラムを書く必要があります.
  • torch/extra/cunn/lib/THCUNN/ディレクトリの下にファイルNewClassを作成する.cu
  • cunnに既存の関数を参照し、関数機能
  • を実現する
    ...
    void THNN_CudaNewClass_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output)
    {
    
    }
    
    void THNN_CudaNewClass_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput)
    {
    
    }
    ...
  • 宣言関数、torch/extra/cunn/lib/THCUNN/THCUNN.hに追加:
  • TH_API void THNN_CudaNewClass_updateOutput(
              THCState *state,
              THCudaTensor *input,
              THCudaTensor *output);
    TH_API void THNN_CudaNewClass_updateGradInput(
              THCState *state,
              THCudaTensor *input,
              THCudaTensor *gradOutput,
              THCudaTensor *gradInput);
  • ニュークラスです.luaでGPUバージョンを呼び出す関数は,CPUバージョンと同様にTHNNで
  • を呼び出す.
  • 再コンパイルインストールcunn
  • cd torch/extra/cunn/
    luarocks make rocks/cunn-scm-1.rockspec
  • のインストールに成功した後、自分のコードにカスタムクラス
  • を使用しました.
    require 'cunn'
    ...
    nn.NewClass()
    ...

    テストtorch/extra/nn/test.luatorch/extra/cunn/test.luaにテストコードを追加し、NewClassの出力が正しいかどうかをテストするために使用できます.具体的には、既存のテストコードを参照してください.追加後、th -lnn -e "nn.test{'NewClass'}"を実行するとテストできます.