機械学習の可視化ツールTensorboardを使ってみた


はじめに

Tensorboardを初めて使いグラフを書きその便利さに感動したので、共有します。
Deep LearningのフレームワークはPyTorchを使用しました。

インストール

Anacondaを使っているので以下のコマンドでTensorboardをインストールします。

conda install tensorboard

コーディング

tb.py
import numpy as np
from torch.utils.tensorboard import SummaryWriter#グラフを書くSummaryWriterをimport

np.random.seed(1000)

x = np.random.randn(1000)

writer = SummaryWriter(log_dir="./logs")#インスタンス生成 保存するディレクトリも指定

for i in range(1000):
    writer.add_scalar("x", x[i], i)#値を書き込む
    writer.add_scalar("sin", np.sin(i), i)

writer.close()#閉じる

ファイル名をtensorboard.pyにするとmoduleと被りImportErrorになるので注意しましょう。

解説

簡単にいうと、上のコードでは、ランダムな値を持つ配列とsin関数をプロットしています。

SummaryWriterをimport

from torch.utils.tensorboard import SummaryWriter
Tensorboardでグラフの描画に必要なmoduleであるSummaryWriterをimportします。

インスタンス生成

writer = SummaryWriter(log_dir="./logs")
これはカレントディレクトリにlogsディレクトリを作成し、そのlogsの中にTensorboard用のファイルが保存されます。

値を代入

writer.add_scalar("x", x[i], i)で配列の値を入れます。
writer.add_scalar(tags, scalar_value, global_step)となっており、tagsでグラフの名前を指定して、scalar_valueで保存する値を代入、global_stepでグラフの横軸の間隔を指定します。

閉じる

writer.close()で最後に閉じましょう。

グラフをみる

tb.pyの実行

上記のコードを実行しましょう。グラフが描画されます。

python tb.py

グラフをみる

以下のコマンドを実行しましょう。--logdir=""で保存したディレクトリを指定しましょう。
今回は./logsです。

tensorboard --logdir="./logs"

そうすると、以下の文がターミナルに出力されます。

TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)

ローカルサーバーが立ち上がるので、ブラウザにhttp://localhost:8000/と打ちましょう。

chromeでみると、グラフが綺麗にプロットしていることがわかります。

ssh先のグラフをみる

Deep Learningのコードは計算量が多くローカルPC(手元のPC)では莫大な時間がかかるので、
研究室にあるサーバーのGPUにsshしてサーバー上でコードを回すことがデフォルトです。
では、そういう場合はリモートサーバーで描画したグラフをローカルPCでどうやってみるのでしょうか?

リモートサーバーにsshする

sshをする時に、-Lオプションを用いてクライアント(ローカルPC)のlocalhost:9000をリモートサーバーのユーザ名@サーバーのIPアドレス:8000に繋げます。

@ローカルPC
ssh ユーザ名@サーバーのIPアドレス -L 9000:localhost:8000

リモートサーバーでtb.pyの実行

sshしたリモートサーバーでグラフを描画するコードを実行しましょう。

@リモートサーバー
python tb.py

Tensorboardの実行

sshしたリモートサーバーでグラフをみるためのコマンドを実行しましょう。
sshした際にローカルPCに繋いだポートは8000なので、--portオプションで8000を指定して実行しましょう。

@リモートサーバー
tensorboard --logdir="./logs" --port 8000

以下のような文が出力されます。

@リモートサーバー
TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)

グラフをみる

さっきはhttp://localhost:8000/をブラウザに入力したら、グラフがみれましたが今回は見れません。

今回はリモートサーバーのポート8000とローカルPCのポート9000を繋げたので、
ローカルPCのブラウザでhttp://localhost:9000/と入力すれば、さっきと同じグラフが見れます。

まとめ

PyTorchでTensorboardを用いてグラフを描画しました。
また、ssh先のリモートサーバーでまわしたコードのグラフをローカルPCでみる方法を紹介しました。
私もこのTensorboardとssh -Lを利用してDeep Learningに活用していきたいと思います。