TabNet(GCP AIPlatformの組み込みアルゴリズム)をBQMLで使う
GCPのAI Platformにいい感じの組み込みアルゴリズムがリリースされたようなので試してみる。
AI Platform 上の TabNet: 高パフォーマンスで説明可能な表形式ラーニング
特徴量の重要度が説明可能でありながらディープニューラルネットワークと同様のパフォーマンスを出すとのこと
データのフォーマットと前処理
以下ページ参照
TabNet 組み込みアルゴリズムを使用したトレーニング
データフォーマット
2020/10/23時点では、入力フォーマットは単一CSVファイルのみ
1.ヘッダー行を削除する
2.1 列目をターゲット列にする
前処理の注意点(特に)
数値項目をカテゴリとして扱いたい場合
データをカテゴリとして扱う場合は、列内のすべての整数値の前に数値以外の接尾辞を追加します。例:{code_101, code_102, code_103}
一意の値が少ない数値項目を数値として扱いたい場合
データを数値として扱う場合は、列内のすべての整数値を浮動小数点に変換します。例:{101.0, 102.0, 103.0}
ジョブの中で行われるデータの変換
以下の順番でデータ変換が行われる
1.分割の割合が指定されている場合は、トレーニング データセットを検証データセットとテスト データセットに分割します。
2.特徴の欠損が 10% を超える行を削除します。
3.列の平均値を使用して、欠損している数値を埋めます。
トレーニング
パラメータはデフォルトでセッティングされているものでもそれっぽい精度にはなる
AI Platformのジョブメニューから、新規トレーニングジョブ、組み込みアルゴリズムによるトレーニングをクリック
トレーニングデータを選択
2020/10/23時点ではGCSバケットに格納されている単一ファイルのみ
トレーニングデータ、モデルの出力ディレクトリを指定(GCS)
必要に応じてテストデータの割合も指定可能
Modelタイプの指定
regression(回帰)かclassification(分類)の選択をする
必要ならばMax Stepも指定する
(トレーニングが全然終わらない場合あるので、そういう場合はMax Stepを指定するとそこで打ち切れる)
ジョブID記入ととスケール階層を選択する
ジョブIDは一意なら何でもOK
スケール階層はBASICとCUSTOMがある(1GBくらいのデータをBASICで動かすとメモリ不足エラーが発生した)
BQMLのモデル作成
上記のジョブが完了したら、BQMLモデルとして作成可能
参考:BigQuery MLでTensorFlowのモデルを呼び出す
CREATE MODEL dataset.model_name
OPTIONS(MODEL_TYPE='TENSORFLOW',
MODEL_PATH="gs://指定した出力フォルダ/model/*")
BQMLで予測
以下を参考に、入力を整えてあげてML.predictを実行する
TabNet 組み込みアルゴリズムのスタートガイド
ポイントはcsv_rowというカラム(必要項目をCSVっぽくしたカラム)とkey(これは多分何でもOK)というカラムが必要
with pred_table
as
(
select
concat(
column1,
',',
column2
) as csv_row,
column1 as key,
column1,
column2
from
table
)
select
*
from
ml.predict(MODEL `モデル名`, table pred_table) pred
出力結果
分類モデルの場合、以下のような感じの出力が返ってくる
aggregated_mask_values | key | logits | predicted_classes | csv_row |
---|---|---|---|---|
0.0 | 1 | -0.6332 | 0 | column1,column2 |
0.5 |
- aggregated_mask_values:どのカラムの特徴量が効いてるかを表してるっぽい、特徴量分の行が出力されるので、不要であればSELECT句から除外する
- key:SQLで設定したkeyの値
- logits:predicted_class判定のための確率、0~1の範囲で表したいときは、 round(exp(logits)/(1+exp(logits)),3) のようにする
- predicted_class:分類結果
競馬予測してみる
先週の菊花賞(2020年10月25日)を予測してみる。
(3着以内確率が高い順に表示)
コントレイルがあまり評価されていない。。
トレーニングデータとか入力データに欠損データが多すぎるのか。。
horse_name | prob |
---|---|
バビット | 0.412 |
アリストテレス | 0.41 |
ダノングロワール | 0.342 |
サトノフラッグ | 0.342 |
ロバートソンキー | 0.32 |
レクセランス | 0.298 |
ディアマンミノル | 0.267 |
コントレイル | 0.243 |
ガロアクリーク | 0.229 |
ブラックホール | 0.225 |
ディープボンド | 0.216 |
サトノインプレッサ | 0.213 |
マンオブスピリット | 0.209 |
ビターエンダー | 0.208 |
ターキッシュパレス | 0.167 |
ヴァルコス | 0.11 |
キメラヴェリテ | 0.086 |
ヴェルトライゼンデ | 0.053 |
遭遇したエラー
The replica master 0 ran out-of-memory
メモリー不足みたいなので、ちょいと上のインスタンスで動かす
以下インスタンスごとの料金
https://cloud.google.com/ai-platform/training/pricing?hl=ja
tensorflow.python.framework.errors_impl.InvalidArgumentError: Restoring from checkpoint failed.
エラー起きてインスタンス変えたり変更入れてまた動かすとき、出力フォルダをエラー時のジョブと同じところに指定しちゃうと前回のジョブの途中からやり始めるのでエラーになるので注意
メモ
多クラス分類する場合
/artifacts/metadata.json
の一番最後に分類の順番ある
Author And Source
この問題について(TabNet(GCP AIPlatformの組み込みアルゴリズム)をBQMLで使う), 我々は、より多くの情報をここで見つけました https://qiita.com/yakamazu/items/f86760be90dca360c981著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .