BigQueryでConfusionMatrixを計算するUDF


概要

AutoMLでバッチ処理をするとBigQueryのデータセットに予測値を吐き出すことができる。
二値の分類モデルを作った時にサクッと色々なテストデータで指標を確認したかったのでUDF化しました。

データの準備

ちなみに、AutoMLバッチ予測した結果をBigQueryに出力すると以下のような感じになります。

predicted_XXXX.tables.score predicted_XXXX.tables.value

それを実際の判定とスコアのテーブルに加工して用意します。

actual score
1 false 0.47
2 false 0.27
3 true 0.68
4 true 0.93
5 false 0.71

とりあえず完成系

CREATE TEMP FUNCTION THRESHOLDS(score FLOAT64, c INT64) AS ((
  SELECT ARRAY_AGG(STRUCT(score >= i/c AS predict, i/c AS threshold))
  FROM UNNEST(GENERATE_ARRAY(1, c-1)) AS i
));

CREATE TEMP FUNCTION INSPECT(actual BOOL, score FLOAT64, c INT64) AS ((
  SELECT ARRAY_AGG(STRUCT(
    ROUND(threshold, 2) AS threshold,
    actual,
    predict,
    CASE
      WHEN actual AND predict THEN "TP"
      WHEN NOT actual AND predict THEN "FP"
      WHEN actual AND NOT predict THEN "FN"
      WHEN NOT actual AND NOT predict THEN "TN"
    END AS class
  ))
  FROM UNNEST(THRESHOLDS(score, c))
));

WITH inspect AS (
  SELECT score, INSPECT(actual, score, 5) AS ins
  FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
    (false, 0.47),
    (false, 0.27),
    (true, 0.68),
    (true, 0.93),
    (false, 0.71)
  ])
)
SELECT 
  threshold,
  COUNTIF(class = "TP") AS TP,
  COUNTIF(class = "TN") AS TN,
  COUNTIF(class = "FP") AS FP,
  COUNTIF(class = "FN") AS FN,
FROM inspect, UNNEST(ins)
GROUP BY threshold

以上を実行すると

threshold TP TN FP FN
1 0.2 2 0 3 0
2 0.4 2 1 2 0
3 0.6 2 2 1 0
4 0.8 1 3 0 1

という結果が得られる。

解説

INSPECT(actual BOOL, score FLOAT64, c INT64)

引数はデータごとの実際の判定、予測スコア、閾値分割数となっている。
上記の例はINSPECT(actual, score, 5)この部分で5を指定しており、0~1を5等分した結果が得られるように作ってある。

WITH inspect AS (
  SELECT score, INSPECT(actual, score, 5) AS ins
  FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
    (false, 0.47),
    (false, 0.27),
    (true, 0.68),
    (true, 0.93),
    (false, 0.71)
  ])
)
SELECT *
FROM inspect

inspectの部分だけ表示すると

score ins.threshold ins.actual ins.predict ins.class
1 0.47 0.2 false true FP
0.4 false true FP
0.6 false false TN
0.8 false false TN
2 0.27 0.2 false true FP
0.4 false false TN
... ... ... ...

と各データごとに閾値を変えた判定結果が入っている。
のでinsというARRAYオブジェクトを展開して、閾値でGROUP BYしカウントすると

SELECT 
  threshold,
  COUNTIF(class = "TP") AS TP,
  COUNTIF(class = "TN") AS TN,
  COUNTIF(class = "FP") AS FP,
  COUNTIF(class = "FN") AS FN,
FROM inspect, UNNEST(ins)
GROUP BY threshold

閾値ごとの混同行列が得られる。

ちなみに、ROUND(threshold, 2)としているのでINSPECT(actual, score, 100)までは動く。

THRESHOLDS(score FLOAT64, c INT64)

scoreをcの数だけ判定するためだけの関数。

CREATE TEMP FUNCTION THRESHOLDS(score FLOAT64, c INT64) AS ((
  SELECT ARRAY_AGG(STRUCT(score >= i/c AS predict, i/c AS threshold))
  FROM UNNEST(GENERATE_ARRAY(1, c-1)) AS i
));

SELECT THRESHOLDS(score, 5) AS ts
FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
  (false, 0.47),
  (false, 0.27),
  (true, 0.68),
  (true, 0.93),
  (false, 0.71)
])
ts.predict ts.threshold
1 true 0.2
true 0.4
false 0.6
false 0.8
... ... ...

おまけ

最初の例だとINSPECTした結果をわざわざ集計しなければならなくてめんどくさい。
のでGROUP BYまでUDFでやってしまうバージョンも作成した。

CREATE TEMP FUNCTION THRESHOLDS(score FLOAT64, c INT64) AS ((
  SELECT ARRAY_AGG(STRUCT(score >= i/c AS predict, i/c AS threshold))
  FROM UNNEST(GENERATE_ARRAY(1, c-1)) AS i
));

CREATE TEMP FUNCTION INSPECT(actual BOOL, score FLOAT64, c INT64) AS ((
  SELECT ARRAY_AGG(STRUCT(
    ROUND(threshold, 2) AS threshold,
    actual,
    predict,
    CASE
      WHEN actual AND predict THEN "TP"
      WHEN NOT actual AND predict THEN "FP"
      WHEN actual AND NOT predict THEN "FN"
      WHEN NOT actual AND NOT predict THEN "TN"
    END AS class
  ))
  FROM UNNEST(THRESHOLDS(score, c))
));

CREATE TEMP FUNCTION INSPECTS(datas ARRAY<STRUCT<actual BOOL, score FLOAT64>>, c INT64) AS ((
  WITH inspect AS (
    SELECT INSPECT(actual, score, c) AS ins
    FROM UNNEST(datas)
  ),
  confusion_matrix AS (
    SELECT
      threshold,
      COUNTIF(class = "TP") AS TP,
      COUNTIF(class = "TN") AS TN,
      COUNTIF(class = "FP") AS FP,
      COUNTIF(class = "FN") AS FN
    FROM inspect, UNNEST(ins)
    GROUP BY threshold
  )
  SELECT ARRAY_AGG(STRUCT(threshold, TP, TN, FP, FN)) FROM confusion_matrix
));

WITH data AS (
  SELECT *
  FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
    (false, 0.47),
    (false, 0.27),
    (true, 0.68),
    (true, 0.93),
    (false, 0.71)
  ])
)
SELECT *
FROM UNNEST(INSPECTS((SELECT ARRAY_AGG(STRUCT(actual, score)) AS datas FROM data), 5))

INSPECTS((SELECT ARRAY_AGG(STRUCT(actual, score)) AS datas FROM data), 5)
の第一引数にactual, scoreの構造体の配列自体を渡してしまい、関数内でGROUP BYしている。

注意点として、こちらはテーブル全てを構造体の配列にしてしまうことに等しいのでデータ数が多すぎると機能しないかもしれない。

おまけ2

混同行列だけみてても仕方ないのでprecisionなどを計算する関数も作った

CREATE TEMP FUNCTION INDICATOR(TP INT64, TN INT64, FP INT64, FN INT64) AS (STRUCT(
  SAFE_DIVIDE(TP+FP, TN+FP+FN+TP) AS positive,
  SAFE_DIVIDE(TN+TP, TN+FP+FN+TP) AS accuracy,
  SAFE_DIVIDE(TP, TP+FP) AS precision,
  SAFE_DIVIDE(TP, FN+TP) AS recall,
  SAFE_DIVIDE(FP, TN+FP) AS fallout,
  SAFE_DIVIDE(2 * SAFE_DIVIDE(TP, TP+FP) * SAFE_DIVIDE(TP, FN+TP), SAFE_DIVIDE(TP, TP+FP) + SAFE_DIVIDE(TP, FN+TP)) AS f1
));
-- 関数定義省略
WITH inspect AS (
  SELECT score, INSPECT(actual, score, 5) AS ins
  FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
    (false, 0.47),
    (false, 0.27),
    (true, 0.68),
    (true, 0.93),
    (false, 0.71)
  ])
)
, confusion_matrix AS (
  SELECT 
    threshold,
    COUNTIF(class = "TP") AS TP,
    COUNTIF(class = "TN") AS TN,
    COUNTIF(class = "FP") AS FP,
    COUNTIF(class = "FN") AS FN,
  FROM inspect, UNNEST(ins)
  GROUP BY threshold
)
SELECT INDICATOR(TP, TN, FP, FN).*
FROM confusion_matrix

実行結果

positive accuracy precision recall fallout f1
1 1.0 0.4 0.4 1.0 1.0 0.5714285714285715
2 0.8 0.6 0.5 1.0 0.6666666666666666 0.6666666666666666
3 0.6 0.8 0.6666666666666666 1.0 0.3333333333333333 0.8
4 0.2 0.8 1.0 0.5 0.0 0.6666666666666666

おまけ3(AUC)

ここまできたらROCやPRのAUCを求めてしまえる。ここまでSQLでやるのかは疑問だが。

CREATE TEMP FUNCTION AUC(arr ARRAY<STRUCT<v1 FLOAT64, v2 FLOAT64>>) AS ((
-- AUCを短冊積分で計算する
  SELECT SUM(v)
  FROM (
    SELECT v2 * (IFNULL(LEAD(v1) OVER (ORDER BY i), 1) - v1) AS v
    FROM UNNEST((
       SELECT ARRAY_AGG(STRUCT(v1, v2) ORDER BY v2)
       FROM UNNEST(arr)
    )) WITH OFFSET i
  )
));

v1に横軸、v2に縦軸を入れるとそのAUCを計算する。ちょっとあってるか自信はない。

WITH inspect AS (
  SELECT score, INSPECT(actual, score, 5) AS ins
  FROM UNNEST(ARRAY<STRUCT<actual BOOL, score FLOAT64>>[
    (false, 0.47),
    (false, 0.27),
    (true, 0.68),
    (true, 0.93),
    (false, 0.71)
  ])
)
, confusion_matrix AS (
  SELECT 
    threshold,
    COUNTIF(class = "TP") AS TP,
    COUNTIF(class = "TN") AS TN,
    COUNTIF(class = "FP") AS FP,
    COUNTIF(class = "FN") AS FN,
  FROM inspect, UNNEST(ins)
  GROUP BY threshold
)
, indicat AS (
  SELECT INDICATOR(TP, TN, FP, FN).*
  FROM confusion_matrix
)
SELECT
  AUC(ARRAY_AGG(STRUCT(fallout, recall))) AS roc_auc,
  AUC(ARRAY_AGG(STRUCT(recall, precision))) AS pr_auc,
FROM indicat
roc_auc pr_auc
1 0.5 0.16666666666666669

データが多いならINSPECTcを増やすと短冊が増えるのでより正確になると思われる。