graphviz でシンプルな分類木を描く


はじめに

根拠の説明しやすさ,理解のしやすさと言った観点から,決定木は意思決定に当たって便利な手法の1つです.こういった中で,生成した分類基準を図として表す場合に,余計な情報が不要な場合があると思います.今回は,各種情報を除くシンプルな決定木のことを分類木と呼び,生成する手法についてまとめます.

やりたいこと

情報を減らして分類木を描画する.

例題

アヤメのデータセットに対して,深さが3のシンプルな分類木を描画する場合,以下のようになると思います.(プログラム割愛)

今回は,gini不純度や全データ数,各データ数などを除いた,シンプルな分類木の作成を目指します.

スクリプト

sklearn の tree.export_graphviz 関数で得たデータに変更を加えることで決定木の形式を変更します.

#! env python
# -*- coding: utf-8 -*-

import graphviz
from sklearn.datasets import load_iris
from sklearn import tree

def remove_gini_to_value(str, n):
    for _ in range(n):
        start, end = str.find('gini = '), str.find('class = ', str.find('gini = '))
        rm_str = str[start:end]
        str = str.replace(rm_str, '')
    return str

if __name__ == '__main__':
    iris = load_iris()
    # 決定木の作成
    clf = tree.DecisionTreeClassifier(max_depth=3)
    clf = clf.fit(iris.data, iris.target)
    # 決定木の描画
    data = tree.export_graphviz(clf, feature_names=iris.feature_names,
                            class_names=iris.target_names,impurity=True,
                            filled=False, rounded=True)
    # gini不純度,合計データ数,各データ数を削除
    data = remove_gini_to_value(data, clf.tree_.node_count)
    graph = graphviz.Source(data)
    display(graph)
    graph.render('tree',format='png')

参考

決定木の作成と可視化に当たって参考になったサイトをまとめておきます.
DecisionTreeClassifier関数の中身,パラメータ説明など
https://github.com/scikit-learn/scikit-learn/blob/2beed55847ee70d363bdbfe14ee4401438fba057/sklearn/tree/_classes.py#L612
export_graphviz関数の説明
https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/tree/_export.py#L665
決定木の内部要素抽出
https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html