データ可視化ライブラリDashを使ってマウスホバーで画像を表示するグラフを作成する


概要

データ可視化ライブラリのDashを使って、マウスホバーで対象データの画像を表示するグラフを作成する手順を記載しています。
Dashを使用するとpythonのみかつ非常に少ないコードでこのような動的なグラフの作成ができます。
作成したグラフは以下のような感じです。

背景

画像データから抽出した高次元の特徴量を、T-SNEやUMAPを使って次元削減して散布図としてグラフを作成してデータの分布を確認するということはよくあると思います。
その際に気になるデータの画像を確認できるような動的なグラフを簡単に作成したいと考えていました。

環境

Google Colabで実施していますが、ローカルのJupyterでも実施可能です。
Jupyter内でグラフを表示していますが、数行変更すれば単独のアプリケーションとして起動することも可能です。この記事の下の補足にアプリケーションで起動する場合のコードを添付しました。

手順

  1. ライブラリをインストール
  2. ライブラリをimport
  3. 補助関数を定義
  4. グラフを作成
  5. Dashで表示

1. ライブラリをインストール

dashとjupyter_dashをインストールします。

!pip install dash
!pip install jupyter_dash

2. ライブラリをimport

利用するライブラリーをimportします。

from jupyter_dash import JupyterDash 
import dash_core_components as dcc 
import dash_html_components as html 
import plotly.express as px
from dash.dependencies import Input, Output
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
import numpy as np
from PIL import Image
from io import BytesIO
import base64

3. 補助関数を定義

numpyのarrayをbase64に変換する関数を定義します。

def numpy_to_b64(array):
    # Convert from 0-16 to 0-255
    array = np.uint8(255 - 255/16 * array)

    im_pil = Image.fromarray(array)
    buff = BytesIO()
    im_pil.save(buff, format='png')
    im_b64 = base64.b64encode(buff.getvalue()).decode('utf-8')
    return im_b64

4. グラフを作成

mnistのデータをT-SNEで2次元に次元削減します。
その結果を散布図で表示するグラフを作成します。

digits = load_digits()
tsne = TSNE(n_components=2, random_state=0)
projections = tsne.fit_transform(digits.data)
fig = px.scatter(
    projections, x=0, y=1,
    color=digits.target
)

5. Dashで表示

マウスホバーで、画像が表示されるようにCallbackを定義します。

app = JupyterDash(__name__)

app.layout = html.Div([
                       html.Div(id="output"),
                       dcc.Graph(id="fig1", figure=fig)
])

@app.callback(
    Output('output', 'children'),
    [Input('fig1', 'hoverData')])
def display_image(hoverData):
    if hoverData:
        idx = hoverData['points'][0]['pointIndex']
        im_b64 = numpy_to_b64(digits.images[idx])
        value = 'data:image/png;base64,{}'.format(im_b64)
        return html.Img(src=value, height='100px')
    return None

app.run_server(mode="inline")

以上を実行するとグラフが表示されます。

補足

アプリケーションとして実行する場合は、app.pyのようなコードになります。
python app.pyのようにプログラムを実行して、http://127.0.0.1:8050/をブラウザーで開くとグラフが表示されます。

app.py
import dash
import dash_core_components as dcc 
import dash_html_components as html 
import plotly.express as px
from dash.dependencies import Input, Output
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
import numpy as np
from PIL import Image
from io import BytesIO
import base64


def numpy_to_b64(array):
    # Convert from 0-16 to 0-255
    array = np.uint8(255. - 255./16. * array)

    im_pil = Image.fromarray(array)
    buff = BytesIO()
    im_pil.save(buff, format="png")
    im_b64 = base64.b64encode(buff.getvalue()).decode("utf-8")
    return im_b64


digits = load_digits()
tsne = TSNE(n_components=2, random_state=0)
projections = tsne.fit_transform(digits.data)
fig = px.scatter(
    projections, x=0, y=1,
    color=digits.target
)

app = dash.Dash(__name__)

app.layout = html.Div([
                       html.Div(id="output"),
                       dcc.Graph(id="fig1", figure=fig)
])

@app.callback(
    Output('output', 'children'),
    [Input('fig1', 'hoverData')])
def display_image(hoverData):
    if hoverData:
        idx = hoverData['points'][0]['pointIndex']
        im_b64 = numpy_to_b64(digits.images[idx])
        value = 'data:image/png;base64,{}'.format(im_b64)
        return html.Img(src=value, height='100px')
    return None

app.run_server(debug=True)

参考

Pythonの可視化ライブラリDashを使う 3 マウスホバーを活用する
https://qiita.com/OgawaHideyuki/items/b4e0c4f134c94037fd4f

Jupyter上でDashを使えるjupyter_dash
https://qiita.com/OgawaHideyuki/items/725f4ffd93ffb0d30b6c