ONNXRuntime の Ruby バインディングを動かしてみた


はじめに

 今となってはそこまで珍しくないかもしれませんが、RubyからDeep Learningの推論エンジンを動してみたので記事にします。
 ONNX Runtimeとは、Microsoftが中心となって開発しているOpen Neural Network Exchange (ONNX) 用の推論エンジンです。つい先日、Chartkickなどの開発で有名なAndrew Kane氏が、ONNXRuntimeのRuby向けバインディングを作成されました。今日はこれを使ってみます。Ankane氏は、最近XgbLightGBMと相次いでRuby向け機械学習バインディングを制作・公開されており、ONNXRuntimeはその最新作です。

Ankane氏公式ブログより

I’m happy to announce it’s now possible to build advanced models in TensorFlow, Scikit-learn, PyTorch, and a number of other tools, and score them in Ruby with minimal friction.

ONNXに変換すれば、だいたいのDeepLearningのモデルはRubyから簡単に呼び出せるようになりそうですね。

使ってみた

せっかくなのでVGGとかではなくResNetを動かしてみましょう。ここからmodelをダウンロードできます。

GUIツールキットにFlammarionを使います。numpyの代わりにNArrayを使います。画像のリサイズ・ピクセル取得は、いろいろやり方があると思うのですが、今回は筆者のブログの通りmini_magickを使用してみました。

require 'mini_magick'
require 'numo/narray'
require 'onnxruntime'
require 'flammarion'

SFloat = Numo::SFloat

Flam = Flammarion::Engraving.new

model = OnnxRuntime::Model.new('resnet152v2.onnx')

path = File.expand_path(ARGV[0])
img = MiniMagick::Image.open(path)

def preprocess(img)
  img.resize '224x224!'
  img_data = SFloat.cast(img.get_pixels)
  img_data = img_data.transpose(2, 0, 1)
  img_data = img_data.expand_dims(0)
  mean_vec = SFloat[0.485, 0.456, 0.406]
  stddev_vec = SFloat[0.229, 0.224, 0.225]
  norm_img_data = SFloat.zeros(1, 3, 224, 224)
  norm_img_data.shape[1].times do |i|
    norm_img_data[true, i, true, true] =
      (img_data[true, i, true, true] / 255 - mean_vec[i]) / stddev_vec[i]
  end
  norm_img_data.to_a
end


labels = File.readlines("synset.txt")

input = preprocess(img)
result = model.predict(data: input)

score = result['resnetv27_dense0_fwd'][0]

lab2, sco2 = labels.zip(score).sort_by{|i| -i[1]}[0..9].reverse.transpose

Flam.orientation = :horizontal
Flam.image("#{path}\" height=400") # HTMLに埋め込まれるため変則的な方法で画像サイズ指定
Flam.pane("graph").plot(y: lab2, x: sco2,
          type: :bar, orientation: 'h',
          height: 400, margin: {l: 400})

Flam.wait_until_closed

こんな風に表示されました。ちゃんと猫であると判定してくれていますね。

感想

model.inputs
model.outputs

で、何をインプットしたらいいか、何をアウトプットすれば良いかがすぐわかるのが便利です。とても使いやすいです。

そのほか

ONNXRuntime以外の推論エンジンでMenohというのがありまして、こちらもRuby向けのバインディングがあります。しかし、ONNXRuntimeの登場で、あまり使われなくなるんじゃないかなという気がします。

参考資料