amazon athenaでhash値を使ったtrain/test splitをした話


3行で

・amazon athenaにクエリを投げて、hash値を使ったtrain/test splitをしたい
・amazon athenaのhash関数は戻り値がbinaryで、binaryを扱える関数が貧弱すぎてかなりハマった
・文字列にはTO_UTF8関数、XXHASH64関数、FROM_BIG_ENDIAN_64関数を順次適用するとhash値が得られる。

モチベーション

データのリークを避けつつtrain/testにデータを分けるのに有効な方法として、hash値を使うものがある。
すべてのレコードをランダムで分けるだけでは、train/testにリークが発生する可能性がある。
(例えば、小売店の取引履歴があったとして、同じIDの人が何度か取引したデータが別のレコードとして扱われていれば、それらがtrain/test間のリークを起こす)
hash値は引数に対して一意的な値を与えられる(と仮定していい)ので、train/testにまたがってほしくないデータは同じhash値になるように同じ値をhash関数に渡す。
そのhash値をもとに分割すればいい。(例えばhash値のmodをとるなどして)
当然アプリケーション側のコードでも実現できることだが、DB側でもできるようにしたかった。
(データが大きくなったときに有利になったりしないかな・・・)

内容

結論

ナイーブなやり方ではうまく行かないし、うまく行かないやり方を紹介しようと思うが、結論の前にあまりごちゃごちゃと書くのも読みにくいのでまずはうまくいくコードを紹介する。

query
SELECT FROM_BIG_ENDIAN_64(XXHASH64(CAST('Amazon Redshift' AS varbinary))) AS hash

結果は以下

hash
-5401464031556620364

例えば、これでtrain/test splitをやるなら、

query
SELECT 
  Your_feature
  FROM_BIG_ENDIAN_64(XXHASH64(CAST(<your_key>))) AS hash
FROM
  Yourdatabase.Yourtable
WHERE
  MOD(ABS(FROM_BIG_ENDIAN_64(XXHASH64(CAST(<your_key>)))),10) < 8

などとやれば、<your_key>のカラムから重複を除いておよそ80%のレコードが得られます。

うまく行かないやり方その1 CAST AS INTEGER

僕も最初に試したやつですが、hash関数の戻り値をCAST AS INTEGERで型変換して数値にする方法です。
例えば以下のクエリはエラーです

query
SELECT CAST(MD5(CAST('Amazon Redshift' AS varbinary)) AS INTEGER)

Cannot cast varbinary to integer

と言われてしまいます。

これは、下記のドキュメントで言及されています。
https://docs.aws.amazon.com/ja_jp/athena/latest/ug/functions-operators-reference-section.html

Amazon Athena クエリエンジンは、Presto 0.172 に基づいています。これらの関数の詳細については、Presto 0.172 の関数と演算子を参照してください。
https://prestodb.github.io/docs/0.172/index.html

そしてバイナリに使える関数は以下です
https://prestodb.github.io/docs/0.172/functions/binary.html
なんであれ、ここのリストにない操作で実現しようと思わないほうが無難でしょう。

うまく行かないやり方その2 MD5をSUBSTRで短くする

よく使われるhash関数であるMD5ですが、戻り値が124ビットです。
バイナリを数値に戻す関数はFROM_BIG_ENDIAN_64ですが、124ビットのバイナリを引数にとれません。
どうせそんなに桁数なんていらないんだから、途中の桁で打ち切ってしまえ、私はそう考えました。
驚くべきことに、Presto 0.172ではSUBSTRはサポートされていません
実際、以下のクエリはエラーです

query
SELECT SUBSTR(MD5(CAST('Amazon Redshift' AS varbinary)),8)

Unexpected parameters (varbinary, integer) for function substr. Expected: substr(varchar(x), bigint, bigint) , substr(varchar(x), bigint) , substr(char(x), bigint) , substr(char(x), bigint, bigint)

というエラーを吐きます。
Presto 0.225ではサポートされているようで、そっちの情報ばっかり出てきたのでめちゃくちゃ混乱しました・・・。

もう一度結論

https://prestodb.github.io/docs/0.172/functions/binary.html
というわけで、上記リストにある関数からhash値取得と数値に変換ができる組み合わせを選ぶと
hash関数: XXHASH64
変換: FROM_BIG_ENDIAN_64
ということになるわけです。

教訓

困ったら公式ドキュメントを・・・