pytorchで学習済みのモデルをROS C++(libtorch)で使う


はじめに

pytorchで学習したモデルをROSで使うとき、

  • モデル学習時の環境がpython3だとpython2環境のROSでの実装が面倒
  • pythonだとどうしても推論の速度が出ない

などで困っていたので、libtorchを使ってモデルをC++で動かしてみたときのメモです。
(ROS2だったらpython3の環境でも使えたけどやっぱり遅い...)

環境

Ubuntu18.04, Ubuntu16.04
ROS Melodic, ROS Kinetic

libtorchのビルド

libtorchはこちらに書いてあるようにダウンロードすることが可能ですが、2019/8/3現在、旧型のABIでビルドされています( https://discuss.pytorch.org/t/sos-libtorch-conflict-with-ros/48977/6 )。

これを用いると、ROSの新型ABIとの競合ができないのでソースからビルドする必要がありました。

ビルドは以下のようにしたらうまくいきました。(https://github.com/pytorch/pytorch/blob/master/docs/libtorch.rst )

$ git clone https://github.com/pytorch/pytorch
$ cd pytorch
$ git submodule update --init --recursive
$ python setup.py build
$ mkdir build_libtorch && cd build_libtorch
$ python ../tools/build_libtorch.py

学習済みモデルの変換

学習済みモデルの変換には、tracingとannotationによる方法がありますが、ここではtracingによる手法の実装例を紹介します。
annotationによる手法はこちらを参照してください。

# 実装例
import torch

model = Model() # nn.Module
model_path = "weight.pkl" # 学習済みの重み
model.load_state_dict(torch.load(model_path))
expample = torch.rand(1, 3, 64, 64) # 入力のshapeに合わせる
traced_net = torch.jit.trace(model, example)
traced_net.save("model.pt")

ROSで学習済みモデルでの推論

rosのパッケージのCMakeList.txtの書き方の例はこんな感じです。
find_package(Torch REQUIRED) の前にcaffe2とtorchのDIRを設定したらmake通りました。

# CMakeLists.txt
cmake_minimum_required(VERSIONS 2.8.3)
project(hoge)

add_compile_options(-std=c++11)

find_package(catkin REQUIRED COMPONENTS
  roscpp
  geometry_msgs
)

set(Caffe2_DIR "$ENV{HOME}/pytorch/torch/share/cmake/Caffe2")
set(Torch_DIR "$ENV{HOME}/pytorch/torch/share/cmake/Torch")
find_package(Torch REQUIRED)

include_directories(
  ${catkin_INCLUDE_DIRS}
)
add_executable(hoge src/hoge.cpp)
target_link_libraries(hope ${catkin_LIBRARIES} ${TORCH_LIBRARIES))

プログラム自体は、単純に先ほど保存したモデルをloadして、forwardにinputを与えるだけで推論してくれます。
C++でのtorchの実装はこちらを参考にしました。

// hoge.cpp
#include "ros/ros.h"
#include <geometry_msgs/Twist.h>
#include <torch/script.h>

class Hoge{
public:
  Hoge();
  void process();
private:
  ros::NodeHnadle nh;
  ros::Publisher vel_pub;
  torch::jit::script::Module module;
};

Hoge::Hoge(){
  vel_pub = nh.advertise<geometry_msgs::Twist>("/cmd_vel", 1);
  module = torch::jit::load("model.pt");
}

void Hoge::process(){
  ros::Rate loop_rate(10);
  while(ros::ok()){
    torch::Tensor input = torch::ones({1, 3, 64, 64}));
    auto output = module.foward({input}).toTensor();
    vel.linear.x = output[0].item<float>();
    vel.angular.z = output[1].item<float>();
    vel_pub.publish(vel);
    loop_rate.sleep();
    ros::spinOnce();
  }
}

int main(int args, char **argv){
  ros::init(args, args, "hoge");
  Hoge hoge;
  hoge.process();
  return 0;
}

参考URL

PYTORCH C++ API
LOADING A PYTORCH MODEL IN C++
PyTorchで学習したモデルをC++から使う @cashiwamochi
pythonで学習したDNNモデルをC++から利用する(PyTorch & libtorch版) @nmatsui