scikit-learnのtree.plot_treeがとても簡単・便利だったので簡単に使い方をまとめてみた


Pythonで実装された機械学習ライブラリのscikit-learnは様々なアルゴリズムを簡単に試せることからしばしば利用されています。花形と言えばTensorFlowやPyTorchですが、お堅い現場ではなかなか使えません。。。そんなscikit-learnで教師有り学習の代表的な手法「決定木」を学習後に描画時に便利な関数がVersion0.21.xから実装されたので従来のGraphVizを用いる方法と比較しつつ試してみました。

従来の可視化方法: GraphVizを利用

従来はGraphVizという別のライブラリをインストールして、利用していました。結構手間が掛かります。。。

GraphVizのインストール@Mac
brew install graphviz
pip install graphviz
GraphVizのインストール@Ubuntu
sudo apt install -y graphviz
pip install graphviz
GraphVizを用いた方法
import graphviz
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

graph = graphviz.Source(tree.export_graphviz(clf, class_names=iris.feature_names, filled=True))
graph

実行結果

実行結果はgraph.render('decision_tree')を実行するとPDFとして保存できます。

tree.plot_treeを利用

tree.plot_treeを用いてGraphVizを利用して描画した物と同様の図を描画してみます。scikit-learnのtreeモジュールに格納されている為、追加のインストールは不要です。(filledオプションはデフォルトではFalseですが、Trueにすると彩色されます)

tree.plot_treeを用いた方法
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
iris = load_iris()
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, filled=True)
plt.show()

実行結果

GraphVizを用いた方法と同じ図を出力出来ました。Jupyter Notebook上で実行すれば、描画結果をそのまま右クリックして画像として保存も出来ます。

2020/11/27追記: クラス名も決定木に表示

class_nameオプションを追加すると最終的に分類されたクラス名も表示出来ます。

plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()

まとめ

scikit-learnのtree.plot_treeと従来のGraphVizを用いる方法を決定木の可視化に対して行い、tree.plot_treeが(従来の方法より)簡単かつ便利だと実感しました。今後積極的に活用していきたいと思います。

Reference