Rで決定木を試してみた、そしてクロス集計で結果を再確認してみた


決定木をより良く理解するために、Rで決定木を実行した。
そして数値がどのように決定木で使用されているのかを理解するためにスプレッドシートで関連する数値を集計してみて確認した。

データの用意が面倒なので、ディフォルトのタイタニック沈没のデータを使ってます。
Rのコードは下記の通り

 install.packages("rpart")
 install.packages("rpart.plot")
 install.packages("partykit")
 library(rpart)
 library(rpart.plot)
 library(partykit)

 tmp <- data.frame(Titanic)
 df <- data.frame(
  Class = rep(tmp$Class, tmp$Freq),
  Sex = rep(tmp$Sex, tmp$Freq),
  Age = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)
 head(df)

 ct <- rpart(Survived ~ Class + Sex + Age, data = df, method = "class")
  print(ct) 

  rpart.plot(ct, type = 1, uniform = TRUE, extra = 1, under = 1, faclen = 0)  
  plot(as.party(ct))  

  write.csv(tmp, "/Users/data.csv", quote=FALSE)  

結果がこちら

生存のYesとNoはまずMaleとFemaleで一番分かれるということがわかる。
直感的な数値を見てみましょう。
やり方は下記のデータテーブルを

このようにピポットテーブルで集計してみました。
所謂クロス集計です。

これで見るとFemaleとMaleの生存率に大きな差があることがわかります。

さらに次の要因も見てみましょう。
ClassとAgeがありますが、今回はAgeを選んで集計してみます。

特に3行目と4行目に注目したいですが、Maleの中でもChildの方が生存率は45%程度あり、Adultは20%しかないことがわかります。

というようにクロス集計でも結果は得られますが、決定木を使用した方が効率が良いということを実感できました。
今回は決定木から見て集計したので楽でしたが、ピポットテーブルだと網羅的に集計しないと生存率の差は分かりませんよね。