Streamlit上でPyCaretを動かす方法


こんにちは。Qiita初投稿です。

はじめに

PythonコードだけでWebアプリを作れるStreamlit
数行のコードで機械学習の前処理から推定までできるPyCaretというものがあります。
Streamlit上でPyCaretを動かす方法を調べると、ライブラリをイジる方法しかヒットせず、少しハードルが高かったのでライブラリをイジらずに動かす方法を試していこうと思います。

環境

Python 3.7
Streamlit 1.5.1
PyCaret 2.3.6

手順

クラス分類や回帰など問題ごとにライブラリが分かれているので、使用するものをimportします。
今回は回帰で試していきます。

from pycaret.regression import *

StreamlitとDataFrameも扱うので、そちらもimportしておきます。

import pandas as pd
import streamlit as st

データの取得

今回はボストンの住宅価格データセットを使います。

from pycaret.datasets import get_data
data = get_data("boston")

前処理

前処理はsetup関数を使います。
デフォルトではデータの型推定が正しいか入力を求められますが、Streamlitでは入力を返せないので無効にします(html=False, silent=True)。

pipe = setup(data, target="medv", html=False, silent=True)

モデル比較

PyCaretでは学習に使うモデルを選択することができます。
それぞれのモデルの比較はcompare_models関数で実行できます。
実行結果をpandasのDataFrameで取り出し、表示させます。

best = compare_models() # モデル比較
best_model_results = pull() # 比較結果の取得
st.write(best_model_results) # 比較結果の表示

デフォルトでは決定係数R2のスコアが良い順にソートされて出力されます。

モデル作成

モデル作成はcreate_model関数を使用します。
モデル比較で表示されるモデル名(name)と、関数の引数で指定するモデル名(ID)が異なります。
(いつからからモデル比較結果のIndexにモデルIDが表示されるようになりました)

model = create_model("et") # Extra Trees Regressorを使用

複数のモデルでブレンドモデルやアンサンブルモデルを作成することもできます。

ブレンドモデル
et = create_model("et") # Extra Trees Regressorを使用
lightgbm = create_model("lightgbm") # Light Gradient Boosting Machine
model = blend_models(estimator_list=[et, lightgbm])
アンサンブルモデル
ensembled_model = ensemble_model(model)

モデル作成は特にStreamlit向けに調整する事項はありません。

モデルの可視化

モデルの可視化はplot_model関数を使用します。
可視化の内容は学習曲線や残差などいくつか用意されていますが、
リファレンスにもある通り、すべての可視化内容がStreamlitで表示できるようにはなっていません。
(対応していても表示に時間がかかるものもありました)
https://pycaret.readthedocs.io/en/latest/api/regression.html#pycaret.regression.plot_model

plot_model(model, plot="cooks", display_format="streamlit")

クックの距離を描写するとこんな感じです。

予測

予測はpredict_model関数を使います。
特にデータを指定しなければ、setup実行時にテスト用に分割してあったデータセットの一部で予測を行います。

predictions = predict_model(model)
st.write(predictions)

予測結果はLabelというカラムに出力されます。
medv(住宅価格)を予測しましたが、何も考えずにモデルを作ったわりには、それなりの精度がでています。

st.cache

Streamlitの性質上、ユーザがUIを操作するたびにPyCaretの関数が再実行されることがあります。
@st.cacheを使って対策を取っておくと良いでしょう。

create_modelの例
@st.cache(allow_output_mutation=True)
def create_model_cache(estimator):
    return create_model(estimator)

さいごに

setupとplot_modelでそれぞれオプションを指定し、また必要に応じてst.cacheを活用すれば
Streamlit上でPyCaretを動かすことができます。
サンプルプログラムをGitHubに公開しているので、よければ試してみてください。

参考