Spark MLlib (ML Pipeline) の決定木モデルの木構造をダンプするコード
決定木はそのロジック的にモデルの解釈がしやすくて、意外と現実世界で利用されることが多いような気がします。
そんなわけで、Spark MLlib (ML Pipeline) の 決定木 のモデルについて、その構造を標準出力する機能を作ったのでメモとして残して置きます。
package biz.k11i.spark.misc
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode}
/**
* spark.ml の DecisionTreeClassificationModel の木構造を標準出力に書き出す。
*/
object DecisionTreePrinter {
def printTree(model: DecisionTreeClassificationModel): Unit = {
model.rootNode match {
case node: InternalNode => printNodes(node, 0)
case leaf: LeafNode => printLeaf(leaf, 0)
}
}
def printNodes(node: InternalNode, numIndents: Int): Unit = {
val indents = " " * numIndents
node.split match {
case cat: CategoricalSplit => println(s"${indents}category, featureIndex=${cat.featureIndex}, left=${cat.leftCategories.mkString(",")}, right=${cat.rightCategories.mkString(",")}")
case con: ContinuousSplit => println(s"${indents}continuous, featureIndex=${con.featureIndex}, threshold=${con.threshold}")
}
Seq(node.leftChild, node.rightChild).foreach {
case internalNode: InternalNode => printNodes(internalNode, numIndents + 1)
case leafNode: LeafNode => printLeaf(leafNode, numIndents + 1)
}
}
def printLeaf(node: LeafNode, indent: Int): Unit = {
val indents = " " * indent
println(s"${indents}prediction=${node.prediction}")
}
}
これにちょこっと手を加えれば、Random forest のモデルも Gradient-boosted trees のモデルも同様にダンプできると思うけど、面倒くさいのでまた今度。
Author And Source
この問題について(Spark MLlib (ML Pipeline) の決定木モデルの木構造をダンプするコード), 我々は、より多くの情報をここで見つけました https://qiita.com/komiya_atsushi/items/b8906a8b0147bbcf6c52著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .