PythonとGraphvizで二分探索木を描画する


こんにちは!

今回はGraphvizを使って、要素に重複のある二分探索木の描画をしてみました。

二分探索木とは?

↓ の画像のように、親ノードより小さいものは親ノードの左下に、大きいものは右下に配置していった二分木のこと。

親ノードと同じ値のものはどちらでもいいですが、どちら側に配置するかは統一しないといけません。

※画像はWikipediaの二分探索木の項目より

Graphvizとは?

グラフを描画してくれるツールパッケージです。本来はDOTという言語を用いて記述するのですが、

Python上でも書けるようにしたgraphvizというライブラリがあるので今回はそれを利用しました。

コード

Graphvizでは同じ値のノードは勝手に一つにまとめられてしまいます。
今回はそれを避けるためのコードも入れました。

※追記
@shiracamus さんのコメントをもとに大幅に(ほぼすべて)書き直しました。ほとんどコピペになってしまった...
せめてもの抵抗(?)としてコメントで解説みたいなのをつけました。


from graphviz import Digraph


class Node:

    def __init__(self, value, depth=0):
        self.value = value
        self.left = None
        self.right = None
        self.depth = depth

    #再帰的にすべての要素が返される。
    #if文はオブジェクトの中身が存在するならTrue,存在しないならFalseを返す
    #という仕組みがある
    def __iter__(self):
        if self.left:
            yield from self.left
        yield self
        if self.right:
            yield from self.right


class BinarySearchTree:

    def __init__(self, values):
        self.root = None
        for value in values:
            self.insert(value)

    def insert(self, value):
        parent_node = None
        #グラフ内を走査しているときの現在地を表すノード
        tmp_node = self.root
        depth = 0
        #ノードを追加する場所を探索
        while tmp_node:
            parent_node = tmp_node
            tmp_node = tmp_node.left if value < tmp_node.value else tmp_node.right
            depth += 1
        #「parent_nodeではないなら」
        #ではなく
        #「parent_nodeが存在しないなら」
        if not parent_node:
            self.root = Node(value, depth)
        else:
            if value < parent_node.value:
                parent_node.left = Node(value, depth)
            else:
                parent_node.right = Node(value, depth)

    def __iter__(self):
        #main関数の下三行でこれが呼び出されている。
        #self.rootはNodeであるので上のNodeクラスの__iter__()がこの後呼び出される
        if self.root:
            yield from self.root

    def print_text(self):
        depth = -1
        #in の後ではsortedにより第一位に深さ、第二位に値、の順に並べられたBinarySearchTreeオブジェクトを
        #リスト化している
        for node in sorted(self, key=lambda node: (node.depth, node.value)):
            if depth != node.depth:
                depth = node.depth
                #print("深さ={}".format{depth})と同じ。古いpythonのバージョンでは
                #以下の書き方ではエラーが出るので注意。
                print(f"深さ={depth}")
            print(f"\t\t{node.value}")

    def view_graph(self):
        graph = Digraph(format="png")
        #tree(BinSearchTreeクラス)がselfである。
        for node in self:
            #値が同じノードの判別のために深さを組み合わせてノードの名前を決定する
            name = f"{node.value} of {node.depth}"
             #Graphvizでノードを追加
            graph.node(name, str(node.value))
            if node.left:
                #値が同じノードの判別のために深さを組み合わせてノードの名前を決定する
                child = f"{node.left.value} of {node.left.depth}"
                #Graphvizで辺を追加
                graph.edge(name, child, label="left", color="blue")
            if node.right:
                child = f"{node.right.value} of {node.right.depth}"
                graph.edge(name, child, label="right", color="red")
        graph.view("bin_tree")


def main():
    A = [32, 75, 30, 31, 65, 5, 435, 4, 5, 43, 523, 5, 534, 5, 43, 534, 5, 43, 534, 5]
    tree = BinarySearchTree(A)
    tree.print_text()
    tree.view_graph()

if __name__ == "__main__":
    main()

実行結果

上がGraphvizで出力された画像、下がターミナルに表示された画像です。

本当はちゃんと右側に大きいノード、左側に小さいノードが来るようにしたかったのですが

僕の調べた限りではGraphvizで方向を指定することはできなかったようなので、↑ の画像のように

辺の色とラベルで方向を表現しました。

知っている方がいましたら教えていただけると幸いです。

※追記
上記の通り解決しました!


辺の色とラベルが必要ない場合は

self.graph.edge()

の引数 label と color を消してもらえれば普通の黒い辺、ラベルなしのグラフが出力されます。

こんな感じ ↓

参考にしたサイト




おわり!