LibTorchを利用してPyTorch訓練したモデルを呼び出す
PyTorchは現在1.1安定バージョンにリリースされており、新しい機能によりモデルの導入がより簡単になりました.本稿では、PyTorchが訓練したモデルを呼び出す方法を記録します.実際には、公式の強力なLibTorchライブラリを利用しています.
LibTorchのインストール
インストールといっても、実は公式のLibTorchパッケージをダウンロードするだけで、公式サイトからPyTorch(1.1)、libtorch、およびcudaのバージョンを選択し、ダウンロードリンクが表示されます.ここではcuda 9です.0のリンクhttps://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zipダウンロードしてパスを探して解凍します.解凍してそこに置いて動かない!!
PyTorchモデルトレーニング
ここでは、最も単純なResNet 50の事前トレーニングモデルを使用しています.ここで、トレースモデルを保存するコードは以下の通りです.
なお、ここでは
C++呼び出し訓練したモデル
C++呼び出しモデルのコードを書く前に、CMakeListsファイルを書きます.
ここでは
このうち
LibTorchのインストール
インストールといっても、実は公式のLibTorchパッケージをダウンロードするだけで、公式サイトからPyTorch(1.1)、libtorch、およびcudaのバージョンを選択し、ダウンロードリンクが表示されます.ここではcuda 9です.0のリンクhttps://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zipダウンロードしてパスを探して解凍します.解凍してそこに置いて動かない!!
PyTorchモデルトレーニング
ここでは、最も単純なResNet 50の事前トレーニングモデルを使用しています.ここで、トレースモデルを保存するコードは以下の通りです.
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("build/airliner.jpg") # build
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()
model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000
resnet.save('resnet.pt')
なお、ここでは
jit
のtraceトラッキングモデルを用いているが、最後の入力で飛行機のカテゴリが得られることは間違いない.ImageNet 1000クラスのシーケンス番号カテゴリは、ここでこのコードを参照してルートディレクトリの下にresnet.pt
のファイルを生成することができる.この文書は、次にC++が呼び出す必要がある.C++呼び出し訓練したモデル
C++呼び出しモデルのコードを書く前に、CMakeListsファイルを書きます.
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example_torch)
set(CMAKE_PREFIX_PATH "XXX/libtorch") // libtorch
find_package(Torch REQUIRED)
find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)
find_package(OpenCV 2.4.3 QUIET)
if(NOT OpenCV_FOUND)
message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
endif()
endif()
add_executable(${PROJECT_NAME} "main.cpp")
target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 11)
ここでは
CMAKE_PREFIX_PATH
のパスを設定します.このパスはlibtorchを解凍するパスです.そうしないとlibtorchライブラリにリンクできません.その中にはOpenCVの構成も設定されています.具体的なOpenCVのインストールについてはここで説明します.そしてC++がPyTorchモデルを呼び出すコード#include
#include
#include
#include
#include
#include
#include
void TorchTest(){
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../resnet.pt");
assert(module != nullptr);
std::cout << "Load model successful!" << std::endl;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::zeros({1,3,224,224}));
at::Tensor output = module->forward(inputs).toTensor();
auto max_result = output.max(1, true);
auto max_index = std::get<1>(max_result).item<float>();
std::cout << max_index << std::endl;
}
void Classfier(cv::Mat &image){
torch::Tensor img_tensor = torch::from_blob(image.data, {1, image.rows, image.cols, 3}, torch::kByte);
img_tensor = img_tensor.permute({0, 3, 1, 2});
img_tensor = img_tensor.toType(torch::kFloat);
img_tensor = img_tensor.div(255);
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../Train/resnet.pt");
torch::Tensor output = module->forward({img_tensor}).toTensor();
auto max_result = output.max(1, true);
auto max_index = std::get<1>(max_result).item<float>();
std::cout << max_index << std::endl;
}
int main() {
// TorchTest();
cv::Mat image = cv::imread("airliner.jpg");
cv::resize(image,image, cv::Size(224,224));
std::cout << image.rows <<" " << image.cols <<" " << image.channels() << std::endl;
Classfier(image);
return 0;
}
このうち
TorchTest
関数は簡単なプレゼンテーションを行っただけで、Classfier
はOpenCVによってピクチャを読み取り、libtorchの関数によってMat
フォーマットをTensor
に変換し(ここでは次元を変換した.OpenCVの次元は[H,W,C]、PyTorchモデルに必要なのは[C,H,W])、最後にPythonコードと同じ答えを出力することができるからである.ここで重要ないくつかの関数は、torch::from_blob()
です.この関数はMat
タイプをTensor
タイプに変換します.torch::jit::load()
:この関数は、名前の通りモデルをロードする関数です.module->forward()
:モデルの順方向伝播の関数で、入力値はvectorタイプmax()
を使用することを推奨します:この関数はlibtorchのmaxで、c++のtupleタイプ(最初の値は次元上の最大値、2番目の値は最大値のシーケンス番号)を返しますので、std::get<1>(max_result)
を使用してシーケンス番号を取り出します.これはtupleタイプの取り出し方法です.