Wio TerminalでTensorflow liteのHello world


Tensorflow lite for Microcontrollersは、メモリが数十キロバイトしかないマイクロコントローラでも機械学習モデルを実行できるように設計されているようで、Arduinoベースのマイクロコントローラでも、加速度計データからのジェスチャー分類、カメラデータを使用した画像分類などができるとのことで、手持ちのWio Terminalで試してみました。

Arduino Tensorflow liteライブラリのインストール

Arduino IDEのライブラリマネージャ上で、Arduino TensorFlow Liteキーワードで検索してください。下図の通り、Tensorflow liteのライブラリがフィルタリング表示されていますが、precompiledではない最新バージョンをインストールしてください。

Hello Worldを動かしてみる

ライブラリのインストールが完了すれば、[ファイル]_[スケッチ例]メニューから、Arduino_TesorFlowLite->hello_worldをロードすることができます。

hello_worldのサンプルプログラムは、そのままコンパイルと書き込みに成功するはずです。書き込み後、シリアルプロッタを表示すると下図のようにSin波形が描画されていくのがわかると思います。

Wio Terminalの画面上にプロット

せっかく画面があるということなので、生成されたSine波を画面上にプロットしてみます。まずは、TFT_eSPI.hをインクルードし、fillCircleでサークルの描画となります。
setup()関数でTFT_eSPIの初期化、loop()関数内のHandleOutputの後あたりに画面表示用のdrawSine関数を追加しました。

hello_world.c
#include <TensorFlowLite.h>

#include "main_functions.h"

#include "constants.h"
#include "output_handler.h"
#include "sine_model_data.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

#include"TFT_eSPI.h"
TFT_eSPI tft;

// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;

// Create an area of memory to use for input, output, and intermediate arrays.
// Finding the minimum value for your model may require some trial and error.
constexpr int kTensorArenaSize = 2 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
}  // namespace

// The name of this function is important for Arduino compatibility.
void setup() {
  tft.begin();
  tft.setRotation(3);
  tft.fillScreen(TFT_BLACK);

  // Set up logging. Google style is to avoid globals or statics because of
  // lifetime uncertainty, but since this has a trivial destructor it's okay.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::MicroErrorReporter micro_error_reporter;
  error_reporter = &micro_error_reporter;

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing, it's a very lightweight operation.
  model = tflite::GetModel(g_sine_model_data);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.",
        model->version(), TFLITE_SCHEMA_VERSION);
    return;
  }

  // This pulls in all the operation implementations we need.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::ops::micro::AllOpsResolver resolver;

  // Build an interpreter to run the model with.
  static tflite::MicroInterpreter static_interpreter(
      model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    error_reporter->Report("AllocateTensors() failed");
    return;
  }

  // Obtain pointers to the model's input and output tensors.
  input = interpreter->input(0);
  output = interpreter->output(0);

  // Keep track of how many inferences we have performed.
  inference_count = 0;
}

int _x, _y = 0;
// The name of this function is important for Arduino compatibility.
void loop() {
  tft.fillCircle(_x, _y, 8, TFT_BLACK);
  // Calculate an x value to feed into the model. We compare the current
  // inference_count to the number of inferences per cycle to determine
  // our position within the range of possible x values the model was
  // trained on, and use this to calculate a value.
  float position = static_cast<float>(inference_count) /
                   static_cast<float>(kInferencesPerCycle);
  float x_val = position * kXrange;

  // Place our calculated x value in the model's input tensor
  input->data.f[0] = x_val;

  // Run inference, and report any error
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk) {
    error_reporter->Report("Invoke failed on x_val: %f\n",
                           static_cast<double>(x_val));
    return;
  }

  // Read the predicted y value from the model's output tensor
  float y_val = output->data.f[0];

  // Output the results. A custom HandleOutput function can be implemented
  // for each supported hardware target.
  HandleOutput(error_reporter, x_val, y_val);
  drawSine(x_val, y_val);

  // Increment the inference_counter, and reset it if we have reached
  // the total number per cycle
  inference_count += 1;
  if (inference_count >= kInferencesPerCycle) inference_count = 0;
}

void drawSine(float x_value, float y_value) {
  char header[32];
  sprintf(header, "x=%f y=%f", x_value, y_value);
  tft.drawString(header, 0, 0);
  _x = tft.width() * (x_value / 6.28);
  _y = tft.height() * (y_value + 1) / 2;
  tft.fillCircle(_x, _y, 8, TFT_WHITE);
  delay(10);
}

動かしてみる

こんな感じで動きました。
IMAGE ALT TEXT HERE
Tensorflow liteは、Arduino Nano 33 BLE SenseやSTM32F746 Discovery kit、Espressif ESP32-DevKitCなどいくつかのデバイスで検証されていますが、M5StackやWio Terminalで動かそうとすると、そのままのサンプルコードでは動かず、入力系のコードを変えていく必要がありそうです。