SparkでDataFrameを使ってグループごとの累積和・総和を求める方法【Python版】


Sparkのpython版DataFrameのWindow関数を使って、カラムをグルーピング&ソートしつつ、累積和を計算するための方法です。

公式のPython APIドキュメントを調べながら模索した方法なので、もっと良い方法があるかも。
使ったSparkのバージョンは1.5.2です。

サンプル・データ

PostgreSQLのテーブルにテスト用データを用意し、pysparkにDataFrameとしてロードします。

$ SPARK_CLASSPATH=postgresql-9.4-1202.jdbc41.jar PYSPARK_DRIVER_PYTHON=ipython pyspark
(..snip..)
In [1]: df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql://localhost:5432/postgres?user=postgres', dbtable='public.foo').load()
(..snip..)
In [2]: df.printSchema()
root
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
In [4]: df.show()
+---+--------------------+---+
|  a|                   b|  c|
+---+--------------------+---+
|  1|2015-11-22 10:00:...|  1|
|  1|2015-11-22 10:10:...|  2|
|  1|2015-11-22 10:20:...|  3|
|  1|2015-11-22 10:30:...|  4|
|  1|2015-11-22 10:40:...|  5|
|  1|2015-11-22 10:50:...|  6|
|  1|2015-11-22 11:00:...|  7|
|  1|2015-11-22 11:10:...|  8|
|  1|2015-11-22 11:20:...|  9|
|  1|2015-11-22 11:30:...| 10|
|  1|2015-11-22 11:40:...| 11|
|  1|2015-11-22 11:50:...| 12|
|  1|2015-11-22 12:00:...| 13|
|  2|2015-11-22 10:00:...|  1|
|  2|2015-11-22 10:10:...|  2|
|  2|2015-11-22 10:20:...|  3|
|  2|2015-11-22 10:30:...|  4|
|  2|2015-11-22 10:40:...|  5|
|  2|2015-11-22 10:50:...|  6|
|  2|2015-11-22 11:00:...|  7|
+---+--------------------+---+
only showing top 20 rows

カラムaがグルーピング用、カラムbがソート用、カラムcが計算対象です。

カラムグループごとの累積和

カラムaでグループ分けしつつ、カラムbでソートし、カラムcの累積和を取ります。

まずはWindowの定義

In [6]: from pyspark.sql.Window import Window

In [7]: from pyspark.sql import functions as func

In [8]: window = Window.partitionpartitionBy(df.a).orderBy(df.b).rangeBetween(-sys.maxsize,0)

In [9]: window
Out[9]: <pyspark.sql.window.WindowSpec at 0x18368d0>

このウィンドウ上でpyspark.sql.functions.sum()を計算したColumnを作成

In [10]: cum_c = func.sum(df.c).over(window)

In [11]: cum_c
Out[11]: Column<'sum(c) WindowSpecDefinition UnspecifiedFrame>

このColumnを元のDataFrameにくっつけた新しいDataFrameを作成

In [12]: mod_df = df.withColumn("cum_c", cum_c)

In [13]: mod_df
Out[13]: DataFrame[a: int, b: timestamp, c: int, cum_c: bigint]

In [14]: mod_df.printSchema()
root
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
 |-- cum_c: long (nullable = true)


In [15]: mod_df.show()
+---+--------------------+---+-----+
|  a|                   b|  c|cum_c|
+---+--------------------+---+-----+
|  1|2015-11-22 10:00:...|  1|    1|
|  1|2015-11-22 10:10:...|  2|    3|
|  1|2015-11-22 10:20:...|  3|    6|
|  1|2015-11-22 10:30:...|  4|   10|
|  1|2015-11-22 10:40:...|  5|   15|
|  1|2015-11-22 10:50:...|  6|   21|
|  1|2015-11-22 11:00:...|  7|   28|
|  1|2015-11-22 11:10:...|  8|   36|
|  1|2015-11-22 11:20:...|  9|   45|
|  1|2015-11-22 11:30:...| 10|   55|
|  1|2015-11-22 11:40:...| 11|   66|
|  1|2015-11-22 11:50:...| 12|   78|
|  1|2015-11-22 12:00:...| 13|   91|
|  2|2015-11-22 10:00:...|  1|    1|
|  2|2015-11-22 10:10:...|  2|    3|
|  2|2015-11-22 10:20:...|  3|    6|
|  2|2015-11-22 10:30:...|  4|   10|
|  2|2015-11-22 10:40:...|  5|   15|
|  2|2015-11-22 10:50:...|  6|   21|
|  2|2015-11-22 11:00:...|  7|   28|
+---+--------------------+---+-----+
only showing top 20 rows

計算できていますね。

カラムグループごとの総和

今度は、カラムaのグループごとに、カラムcの総和を計算します。
DataFrameをgroupBy()でpyspark.sql.GroupedDataにして、pyspark.sql.GroupedData.sum()を使います。
さっきのsum()とややこしいけど、こちらはColumnオプジョクトを引数に持たすとエラーが出るので注意します。

In [25]: sum_c_df = df.groupBy('a').sum('c')

また、先ほどと違ってこれはWindow関数ではないので、返ってくる結果はDataFrameです。
しかも、総和を格納したカラム名は勝手に決まります。

In [26]: sum_c_df
Out[26]: DataFrame[a: int, sum(c): bigint]

うーん、ややこしい。

とりあえず、元のDataFrameにカラムとしてくっつけます。

In [27]: mod_df3 = mod_df2.join('a'sum_c_df, 'a'()

In [28]: mod_df3.printSchema()
root
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
 |-- cum_c: long (nullable = true)
 |-- sum(c): long (nullable = true)


In [29]: mod_df3.show()
(..snip..)
+---+--------------------+---+-------+------+
|  a|                   b|  c|  cum_c|sum(c)|
+---+--------------------+---+-------+------+
|  1|2015-11-22 10:00:...|  1|      1|    91|
|  1|2015-11-22 10:10:...|  2|      3|    91|
|  1|2015-11-22 10:20:...|  3|      6|    91|
|  1|2015-11-22 10:30:...|  4|     10|    91|
|  1|2015-11-22 10:40:...|  5|     15|    91|
|  1|2015-11-22 10:50:...|  6|     21|    91|
|  1|2015-11-22 11:00:...|  7|     28|    91|
|  1|2015-11-22 11:10:...|  8|     36|    91|
|  1|2015-11-22 11:20:...|  9|     45|    91|
|  1|2015-11-22 11:30:...| 10|     55|    91|
|  1|2015-11-22 11:40:...| 11|     66|    91|
|  1|2015-11-22 11:50:...| 12|     78|    91|
|  1|2015-11-22 12:00:...| 13|     91|    91|
|  2|2015-11-22 10:00:...|  1|      1|    91|
|  2|2015-11-22 10:10:...|  2|      3|    91|
|  2|2015-11-22 10:20:...|  3|      6|    91|
|  2|2015-11-22 10:30:...|  4|     10|    91|
|  2|2015-11-22 10:40:...|  5|     15|    91|
|  2|2015-11-22 10:50:...|  6|     21|    91|
|  2|2015-11-22 11:00:...|  7|     28|    91|
+---+--------------------+---+-------+------+
only showing top 20 rows

うまくグループごとの総和が計算できていますね。

カラムグループごとの(総和 - 累積和)

では、カラムcについて総和までの残り値を計算しましょう。つまり、総和 - 累積和です。

In [30]: diff_sum_c = mod_df3[('sum(c)'] - mod_df3['cum_c']

In [31]: mod_df4 = mod_df3.withColumn("diff_sum_c", diff_sum_c)

In [34]: mod_df4.show()
(..snip..)
+---+--------------------+---+-------+------+----------+
|  a|                   b|  c|cum_c_2|sum(c)|diff_sum_c|
+---+--------------------+---+-------+------+----------+
|  1|2015-11-22 10:00:...|  1|      1|    91|        90|
|  1|2015-11-22 10:10:...|  2|      3|    91|        88|
|  1|2015-11-22 10:20:...|  3|      6|    91|        85|
|  1|2015-11-22 10:30:...|  4|     10|    91|        81|
|  1|2015-11-22 10:40:...|  5|     15|    91|        76|
|  1|2015-11-22 10:50:...|  6|     21|    91|        70|
|  1|2015-11-22 11:00:...|  7|     28|    91|        63|
|  1|2015-11-22 11:10:...|  8|     36|    91|        55|
|  1|2015-11-22 11:20:...|  9|     45|    91|        46|
|  1|2015-11-22 11:30:...| 10|     55|    91|        36|
|  1|2015-11-22 11:40:...| 11|     66|    91|        25|
|  1|2015-11-22 11:50:...| 12|     78|    91|        13|
|  1|2015-11-22 12:00:...| 13|     91|    91|         0|
|  2|2015-11-22 10:00:...|  1|      1|    91|        90|
|  2|2015-11-22 10:10:...|  2|      3|    91|        88|
|  2|2015-11-22 10:20:...|  3|      6|    91|        85|
|  2|2015-11-22 10:30:...|  4|     10|    91|        81|
|  2|2015-11-22 10:40:...|  5|     15|    91|        76|
|  2|2015-11-22 10:50:...|  6|     21|    91|        70|
|  2|2015-11-22 11:00:...|  7|     28|    91|        63|
+---+--------------------+---+-------+------+----------+
only showing top 20 rows

補足

今回気付きましたが、SPARK_CLASSPATHを使うのはSpark 1.0以上では推奨されていないようです。
pyspark起動時に以下のようなメッセージが出ました。

15/11/22 12:32:44 WARN spark.SparkConf: 
SPARK_CLASSPATH was detected (set to 'postgresql-9.4-1202.jdbc41.jar').
This is deprecated in Spark 1.0+.

Please instead use:
 - ./spark-submit with --driver-class-path to augment the driver classpath
 - spark.executor.extraClassPath to augment the executor classpath

どうも、クラスタを利用する場合には異なるサーバでこの環境変数が正しく伝わらないため、別のパラメータを使うことが推奨されているようです。

うむむ。
こういうローカルと分散環境の違い、きちんと把握していかないとなぁ。