Pandas の groupby で処理してグラフを描く


はじめに

Kaggle の Titanic で遊び始めているが, 欠損値の補完やハイパーパラメータの見直しの前に, まずデータをしっかり見ようと思い, データを眺めている. 読み込んだデータを, 例えば Survived の値でグルーピングしてグラフを描くということをササッとやりたいのだが, なかなかうまくいかない. Pandas の "GroupBy" の理解が不十分だからだ.

ネットには, 先人たちのグラフ描画の例がたくさんあるが, 私の理解の道筋を記すことで, 初心者の役に立てるのではないか? と思って, この記事を書く.

目指すゴール

下記のようなグラフを描くこと.

このグラフは, 横軸が Ticket の記号, 縦軸が生存 (s), 死亡 (d), 不明 (na) の人数を積み上げたもので, 合計人数で降順にソートしている. 例えば, 一番左端の CA. 2343 のチケット記号は, 合計 11 人, 不明が 4 名, 残りの 7 名が死亡となっている.

こんなグラフをササッと描きたい.

データを読み込む

データを読み込んで, Ticket のデータで, 同じ記号ごとの数を調べる.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

train_data = pd.read_csv("../train.csv")
test_data = pd.read_csv("../test.csv")
total_data = pd.concat([train_data, test_data]) # train_data と test_data を連結

ticket_freq = total_data["Ticket"].value_counts()
CA. 2343        11
CA 2144          8
1601             8
S.O.C. 14879     7
3101295          7
                ..
350404           1
248706           1
367655           1
W./C. 14260      1
350047           1
Name: Ticket, Length: 929, dtype: int64

CA. 2343 が 11 人, CA 2144 が 8 人, などが分かる.

グラフ用のデータを作る

groupby でグループ化

まず, total_data をチケット記号でグルーピングする.

total_data_ticket = total_data.groupby("Ticket")
# 出力
<pandas.core.groupby.generic.DataFrameGroupBy object at 0x000001F5A14327C8>

groupby の欠点は, データの中身を表示してくれないことだ. ここは, グループ化された と頭の中で理解して, 次へ行く.

生存情報だけ取り出す

次に, 生存情報 (Survived) を取り出す.

total_data_ticket = total_data.groupby("Ticket")["Survived"]
total_data_ticket
# 出力
<pandas.core.groupby.generic.SeriesGroupBy object at 0x000001F5A1437B48>

ここでもデータは表示してくれない.

生存, 死亡, 不明ごとに数を数える

引き続き, value_counts() を使って, Survived の値ごとの数を数える. dropna=False とすることで, N/A もカウントする.

total_data_ticket = total_data.groupby("Ticket")["Survived"].value_counts(dropna=False)
total_data_ticket
# 出力
Ticket       Survived
110152       1.0         3
110413       1.0         2
             0.0         1
110465       0.0         2
110469       NaN         1
                        ..
W.E.P. 5734  NaN         1
             0.0         1
W/C 14208    0.0         1
WE/P 5735    0.0         1
             1.0         1
Name: Survived, Length: 1093, dtype: int64

データの形を変える

グラフを描くために, 生存, 死亡, 不明のデータが列方向に並ぶようなデータに変える. 使うのは unstack().

total_data_ticket = total_data.groupby("Ticket")["Survived"].value_counts(dropna=False).unstack()
total_data_ticket
# 出力
Survived    NaN 0.0 1.0
Ticket          
110152  NaN NaN 3.0
110413  NaN 1.0 2.0
110465  NaN 2.0 NaN
110469  1.0 NaN NaN
110489  1.0 NaN NaN
... ... ... ...
W./C. 6608  1.0 4.0 NaN
W./C. 6609  NaN 1.0 NaN
W.E.P. 5734 1.0 1.0 NaN
W/C 14208   NaN 1.0 NaN
WE/P 5735   NaN 1.0 1.0
929 rows × 3 columns

グラフを描く

N/A を数字に変える

上の出力を見ると, 値に NaN がまだ残っている. そこで NaN を 0 にする.

total_data_ticket.fillna(0, inplace=True)
total_data_ticket
# 出力
Survived    NaN 0.0 1.0
Ticket          
110152  0.0 0.0 3.0
110413  0.0 1.0 2.0
110465  0.0 2.0 0.0
110469  1.0 0.0 0.0
110489  1.0 0.0 0.0
... ... ... ...
W./C. 6608  1.0 4.0 0.0
W./C. 6609  0.0 1.0 0.0
W.E.P. 5734 1.0 1.0 0.0
W/C 14208   0.0 1.0 0.0
WE/P 5735   0.0 1.0 1.0
929 rows × 3 columns

列名を変える

列名が NaN, 0.0, 1.0 となっているが, これでは扱いにくいので, 列名を変える.

total_data_ticket.columns = ["nan", "d", "s"]
total_data_ticket
# 出力
    nan d   s
Ticket          
110152  0.0 0.0 3.0
110413  0.0 1.0 2.0
110465  0.0 2.0 0.0
110469  1.0 0.0 0.0
110489  1.0 0.0 0.0
... ... ... ...
W./C. 6608  1.0 4.0 0.0
W./C. 6609  0.0 1.0 0.0
W.E.P. 5734 1.0 1.0 0.0
W/C 14208   0.0 1.0 0.0
WE/P 5735   0.0 1.0 1.0
929 rows × 3 columns

行ごとの合計人数を計算する

合計人数で降順にソートしたいので, 合計人数を計算して, 新しい列に保存する. 合計を計算するには sum() を使うが, 列方向に計算するので sum(axis=1) としている.

total_data_ticket["count"] = total_data_ticket.sum(axis=1)
total_data_ticket
#出力
    nan d   s   count
Ticket              
110152  0.0 0.0 3.0 3.0
110413  0.0 1.0 2.0 3.0
110465  0.0 2.0 0.0 2.0
110469  1.0 0.0 0.0 1.0
110489  1.0 0.0 0.0 1.0
... ... ... ... ...
W./C. 6608  1.0 4.0 0.0 5.0
W./C. 6609  0.0 1.0 0.0 1.0
W.E.P. 5734 1.0 1.0 0.0 2.0
W/C 14208   0.0 1.0 0.0 1.0
WE/P 5735   0.0 1.0 1.0 2.0
929 rows × 4 columns

これで, グラフを描く準備は整った.

グラフを描く

人数の領域を決めて, 降順にソートする

まずコードを示して, 順番に説明する.

total_data_ticket[total_data_ticket["count"] > 3].sort_values("count", ascending=False)[["nan", "d", "s"]].plot.bar(figsize=(15,10),stacked=True)
コード 内容
total_data_ticket[total_data_ticket["count"] > 3] "count" が 3 より大きいデータ
.sort_values("count", ascending=False) "count" で降順にソート
[["nan", "d", "s"]] 左記の 3 つの列だけ取り出す ("count" はお役御免)
.plot.bar(figsize=(15,10),stacked=True) 棒グラフを描く. サイズを指定し, 積み上げ方式にした

これで, 冒頭に示したグラフが書ける.

これを見ると, CA. 2343CA 2144 の人は Survived = 0 かな…とか想像できる.

全体のコード

最後に全体のコードを示す.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

train_data = pd.read_csv("../train.csv")
test_data = pd.read_csv("../test.csv")
total_data = pd.concat([train_data, test_data])

ticket_freq = total_data["Ticket"].value_counts()
ticket_freq

total_data_ticket = total_data.groupby("Ticket")["Survived"].value_counts(dropna=False).unstack()

total_data_ticket.fillna(0, inplace=True)
total_data_ticket.columns = ["nan", "d", "s"]
total_data_ticket["count"] = total_data_ticket.sum(axis=1)
total_data_ticket[total_data_ticket["count"] > 3].sort_values("count", ascending=False)[["nan", "d", "s"]].plot.bar(figsize=(15,10),stacked=True)

おわりに

この手法を使って, EmbarkedCabin, Name の苗字や敬称など, 他の非数値データも確認していく.

参考