機械学習モデルのパラメータと予測平面の関係を可視化するwebアプリを構築した話【tidymodels,Shinyapps】


はじめに

早速ですが、このようなwebアプリを作成しました。

UIはこんな感じでパラメタと予測・分類平面や性能を可視化します。

まだまだプロトタイプの位置づけとして未実装や詳細を詰めていない部分が多くありますが、似たようなことをしたい方の参考になればと思います。

  • 本記事で扱わないこと
    • 機械学習モデルの理論やパラメータの説明
    • 学習済みモデルの性能確認について
    • shinyでのwebアプリ構築方法
  • 本記事で扱うこと
    • 予測平面や性能評価を確認するために使用したコード
    • shiny appsを使って自作のshinyアプリを一般公開する方法

もし、学習済みモデルの性能を確認するためのアプリが欲しかったのであれば、tidymodelsの「shinymodels」の開発を追いかけてみてください。
モデルの性能評価や解釈についてを扱うパッケージで、現在はまだ対応しているモデルが少ないですが、今後追加されるかもしれません。

動機

  • 機械学習のパラメータを動かした時、過学習するのか分からなくなる
  • 決定境界、予測平面、見てるだけでオモロイ
  • Rでナニカを作って、webアプリとして公開したい場合どうしたらいいか知りたかった
  • tidymodelsの扱うモデル一覧紹介的なモノを作りたかった

動機としては上記のようなことを思っていました。

早速、shiny appsでのwebアプリの公開

Rにはshinyパッケージというwebアプリケーションを作るためのフレームワークがあります。
shinyの扱い方については以下でチュートリアルを紹介しています。

shinyパッケージを使うことでRの実行結果を容易に反映できるwebアプリケーションが作成できます。
しかしローカルホスト上に立ち上がるだけで、自分のPC内だけでなく一般公開するにはサーバーが必要になってきます。
shinyapps.ioでは、shinyサービスを公開することに特化した環境を提供しています。

shinyappsへのwebアプリのアップロード方法は、

  • shinyapps.ioへの登録
  • コーディング済のshinyアプリのコードが置かれたディレクトリをrsconnectで指定
library(rsconnect)
library(shiny)

rsconnect::setAccountInfo(name='ringa-hyj',
                          token='',
                          secret='')

rsconnect::deployApp('C:\\Users\\ringa\\Desktop\\shinyapps\\ml_param')

上記のnameはshinyapps.ioへの登録名、tokenやsecretはAccount > Tokensから発行してください。

shinyapps.ioでは無料でいくつかのwebアプリを公開できますが有料プランもあります。
有料になるとメモリサイズの拡張や、アクセス制限を付けたアプリ公開が可能になります。

あとはアプリケーション起動用のURLが発行されているので接続するだけです。
必要のないときはsleepになり、アクセスされると起動する仕組みです。

以上。

余談。実装コード一部紹介

内部的に実行しているRのコードも簡単に紹介しておきます。

まず利用しているデータはmodeldataパッケージの「biomass」と「two_class_dat」を使っています。
global.Rにてmodeldataから訓練、テストデータを切り分けます。
予測・分類平面を可視化するための平面グリッドもglobal.Rの中で作成しています。


spl_two <- initial_split(two_class_dat,prop = 0.8,strata = Class)
train_two <- spl_two %>% training()
test_two <- spl_two %>% testing()
A=seq(min(two_class_dat$A),max(two_class_dat$A),length=100)
B=seq(min(two_class_dat$B),max(two_class_dat$B),length=100)
preds_base <- expand_grid(A,B)

性能評価のためのコードや、パラメータを代入可能にしたコードをglobal.R内で関数化し、server.Rで実行します。


roc_auc_plot <- function(data,model,title){

predict(model,data,type = "prob") %>% 
  bind_cols(data) %>% 
  roc_curve(Class,.pred_Class1) %>% 
  autoplot()+
  theme_bw()+
  ggtitle(title)

}
rpart_class_fit <- function(p1,p2,p3){
  
  rpart_mod <<- decision_tree(
    cost_complexity = p1,
    tree_depth = p2,
    min_n = p3
  ) %>%
    set_engine("rpart") %>%
    set_mode("classification")
  
  rpart_recipe <<- 
    recipe(formula = Class ~ ., data = train_two) %>% 
    step_normalize(all_numeric(), -all_outcomes()) %>% 
    step_novel(all_nominal(), -all_outcomes()) %>% 
    step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>% 
    step_zv(all_predictors()) 
  
  wfl_rpart <<- workflow() %>% 
    add_model(rpart_mod) %>% 
    add_recipe(rpart_recipe)
  
  set.seed(123) 
  res <<- wfl_rpart %>% fit(train_two)
  
}

モデルのパラメタをui.Rにて入力する


...略
numericInput(inputId = "ranger_mtry_class",label = "this is mtry",value = "2",min = 0,width = '50%'),
numericInput(inputId = "ranger_min_n_class",label = "this is min_n",value = "2",min = 0,width = '50%'),
numericInput(inputId = "ranger_trees_class",label = "this is trees",value = "1000",min = 0,width = '50%'),
略...                  

関連URL

・言うまでもなく公式のドキュメント

・類似のデプロイ方法紹介記事

最後に

ご指摘、機能追加リクエスト等あれば歓迎です。