Spark MLlib (ML Pipeline) の決定木モデルの木構造をダンプするコード


決定木はそのロジック的にモデルの解釈がしやすくて、意外と現実世界で利用されることが多いような気がします。

そんなわけで、Spark MLlib (ML Pipeline) の 決定木 のモデルについて、その構造を標準出力する機能を作ったのでメモとして残して置きます。

package biz.k11i.spark.misc

import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode}

/**
  * spark.ml の DecisionTreeClassificationModel の木構造を標準出力に書き出す。
  */
object DecisionTreePrinter {
  def printTree(model: DecisionTreeClassificationModel): Unit = {
    model.rootNode match {
      case node: InternalNode => printNodes(node, 0)
      case leaf: LeafNode => printLeaf(leaf, 0)
    }
  }

  def printNodes(node: InternalNode, numIndents: Int): Unit = {
    val indents = "  " * numIndents

    node.split match {
      case cat: CategoricalSplit => println(s"${indents}category, featureIndex=${cat.featureIndex}, left=${cat.leftCategories.mkString(",")}, right=${cat.rightCategories.mkString(",")}")
      case con: ContinuousSplit => println(s"${indents}continuous, featureIndex=${con.featureIndex}, threshold=${con.threshold}")
    }

    Seq(node.leftChild, node.rightChild).foreach {
      case internalNode: InternalNode => printNodes(internalNode, numIndents + 1)
      case leafNode: LeafNode => printLeaf(leafNode, numIndents + 1)
    }
  }

  def printLeaf(node: LeafNode, indent: Int): Unit = {
    val indents = "  " * indent
    println(s"${indents}prediction=${node.prediction}")
  }
}

これにちょこっと手を加えれば、Random forest のモデルも Gradient-boosted trees のモデルも同様にダンプできると思うけど、面倒くさいのでまた今度。