Athenaの実行結果をダウンロードして読み込むモジュール


AWSのAthenaに対して実行結果を取得するバッチ処理を書いていました。その際に、以下のようなプログラムを書けたら楽だと考えていたのですが、boto3はAPI単位の処理しか書かれていません。

というわけで、多少雑ですが自前で実装しました。

main.py
from lib import athena
import csv

# 一時ファイルの保存先のs3のパスを指定
client = athena.AthenaClient('s3://ninomiyt-example-bucket/athena/')

# SQLを実行し、結果のファイルをダウンロードする
job = client.execute('SELECT * FROM table_name', 'database_name')
filepath = job.fetch('/tmp/athena/', sleep_sec=5)

# ダウンロードしたcsvの読み込み
with filepath.open('r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        pass
        # 行ごとの処理

以下がそのプログラムです。

lib/athena.py
import csv
from pathlib import Path
from time import sleep
import boto3
from retry import retry
from . import s3


REGION_NAME = 'ap-northeast-1'


class AthenaClient(object):
    """Athenaのクライアント"""

    def __init__(self, output_location):
        """Args:
            output_location (str): 出力先を表す"s3://{バケット名}/{パス}"の文字列

        """
        self.athena = boto3.client('athena', region_name=REGION_NAME)
        self.output_location = output_location

    @retry(tries=4, delay=5)
    def execute(self, query_string, database):
        """Athenaに対してSQLを実行する。

        Args:
            query_string (str): 実行するSQL
            database (str): 実行対象のDB名
        Returns:
            QueryExecutionJob: クエリの実行結果のオブジェクト

        """
        res = self.athena.start_query_execution(
            QueryString=query_string,
            QueryExecutionContext={'Database': database},
            ResultConfiguration={'OutputLocation': self.output_location}
        )
        return QueryExecutionJob(self.athena, res)


class QueryExecutionJob(object):
    """各クエリの実行結果を表すクラス"""

    def __init__(self, athena, query_exec):
        """Args:
            athena (boto3.client): クエリを実行したboto3のクライアント
            query_exec (Dict): start_query_executionのAPIの戻り値

        """
        self.athena = athena
        self.query_exec_id = query_exec['QueryExecutionId']

    def fetch(self, localdir, *, sleep_sec):
        """実行結果をローカルにダウンロードする。

        Args:
            localdir (str/pathlib.Path): 保存先のディレクトリ
            sleep_sec (int/float): クエリが完了していない場合に待つ秒数
        Returns:
            pathlib.Path: ローカルにダウンロードしたテキストファイルの情報
        Raises:
            QueryExecutionError: クエリが失敗した場合
            通信失敗時のエラー

        """
        while not self._is_complete():
            sleep(sleep_sec)
        return self._download_csv_file(localdir)

    def _is_complete(self):
        """クエリが完了しているか調べる

        Returns:
            bool: クエリが完了していればTrue
        Raises:
            AthenaExecutionError: クエリが失敗した場合
            通信失敗時のエラー

        """
        state = self._get_query_execution_state()
        if state in ('FAILED', 'CANCELLED'):
            raise QueryExecutionError(
                f'{self.query_exec_id} のジョブがState: {state} で失敗しました')
        return state == 'SUCCEEDED'

    def _download_csv_file(self, localdir):
        """S3の実行結果のファイルをダウンロードする
        保存先のディレクトリが存在しない場合は作成する

        Args:
            localdir(str/pathlib.Path): 保存先のディレクトリ
        Returns:
            pathlib.Path: ローカルにダウンロードしたテキストファイルの情報
        Raises:
            通信失敗時のエラー

        """
        s3_location = self._get_output_location()
        localpath = Path(localdir) / s3_location.split('/')[-1]
        # 保存先を/tmp/以下で、元々ディレクトリが存在しないことも想定しています
        localpath.parent.mkdir(parents=True, exist_ok=True)
        s3.download_file(s3_location, localpath)
        return localpath

    @retry(tries=4, delay=5)
    def _get_query_execution_state(self):
        """クエリの実行状況を返す。

        Returns:
            str: クエリの実行状況 (RUNNING|SUCCEEDED|FAILED|CANCELLED)

        """
        res = self.athena.get_query_execution(
            QueryExecutionId=self.query_exec_id)
        return res['QueryExecution']['Status']['State']

    @retry(tries=4, delay=5)
    def _get_output_location(self):
        """アウトプットCSVファイルの場所を取得する。

        Returns:
            str: s3://{バケット名}/{パス} 形式の文字列
        """
        res = self.athena.get_query_execution(
            QueryExecutionId=self.query_exec_id)
        return res['QueryExecution']['ResultConfiguration']['OutputLocation']


class QueryExecutionError(Exception):
    """クエリが失敗・中止されたことを表すエラー"""
    pass
lib/s3.py
import csv
from urllib.parse import urlparse
import boto3
from retry import retry


@retry(tries=4, delay=5)
def download_file(s3_path, localpath):
    """S3のファイルをローカルにダウンロードする。

    Args:
        s3_path (str): S3にあるファイルのパス
        localpath (str/pathlib.Path): ローカルのファイルパス

    """
    s3 = boto3.resource('s3')
    bk, key = _split_s3_path(s3_path)
    bucket = s3.Bucket(bk)
    bucket.download_file(key, str(localpath))


def _split_s3_path(s3_path):
    """s3://{バケット名}/{パス}を分割する。
    こちらを参考にしました。
    https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path

    Args:
        s3_path(str): s3://{バケット名}/{パス}の文字列
    Returns:
        tuple: 以下の値を返す
            - str: バケット名
            - str: S3上のパス

    """
    parsed = urlparse(s3_path)
    return parsed.netloc, parsed.path.strip('/')

今回実装していて、改めてpathlibの便利さを感じました。

「元々保存先のローカルファイルパスにファイルがある場合」などの例外ケースは考えられていないので参考にする場合は注意してください。また、リトライ処理の回数などは適当に決めてしまっています。

簡単なライブラリとして切り出してみても面白いかなと思ったのですが、「BigQueryでも同じようなライブラリ欲しい」とか「Athenaであると便利な他の機能も入れたい」とか、いろいろ考えてしまうので一旦ブログの形で出しておきます。