matplotlibを使ってJupyterLab上でインタラクティブな3Dグラフを作成する


1. はじめに

  • インタラクティブなグラフは、数値の挙動を確認する時に何かと便利です。
  • また、3Dのグラフを表示した際には、インタラクティブに見る方向を変えられると便利です。
  • Jupyter notebookを使っていた頃は、ipywidgetという仕組みを使って、インタラクティブに動くグラフを作れたのですが、現時点ではJupyterLabでは対応していないように思います。
  • JupyterLab上でこのような仕組みについて紹介した記事が、私が調べた限りではあまり多くはなかったように思いましたので、投稿してみようと思いました。

2. 概要

  • matplotlibのwidgetという仕組みを使うと、GUI backendに依存せずに、Axesに配置するような形でインタラクティブなWidgetを作ることが可能です。

  • 以下、widgetのドキュメントからの抜粋です。

Widgets that are designed to work for any of the GUI backends. All of these widgets require you to predefine a matplotlib.axes.Axes instance and pass that as the first arg.

  • JupyterLab上で、こんな感じのグラフが簡単に作れます。(このグラフはブラウザのJupyerLab上で動いています。)

3. 具体的なやり方

 ということで、上のようなグラフの作成方法をご説明します。なお、以下をJupyterLab上で作りましたが、先ほどのご説明の通り、matplotlib.widgetはGUI Backendに依存しないということなので、Jupyter notebookはもちろん、LinuxやWindows, MacのGUI上でも同様の結果になるものと思います。(手元では試してませんが)

(1) 必要なライブラリをインポート

まず、必要なライブラリのインポートを行います。また、バックエンドをwidgetに設定する必要があります。

%matplotlib widget
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.widgets import Slider
import numpy as np

(2) サンプルとして表示するデータを作成

ここでは、例として、$z=x^2+2y^2$のグラフを書いてみることにします。

x = y = np.arange(-20, 20, 0.5)
X, Y = np.meshgrid(x, y)
Z = X*X + 2 * Y*Y

(3) Figureを作成

 ご存知の通り、Figureは、matplotlibでグラフを書くキャンパスのようなものです。Figureの上にAxesを載せ、その上にPlotをしていくイメージです。

# Figureの設定
fig = plt.figure(figsize=(5,5))

(4) Axesの作成

次に、Figure上にAxesを作成します。

  • 今回は、3Dグラフのプロット用のAxesを1つ、Slider表示用のAxesを2つ作成します。
  • どんなやり方でもいいのですが、今回はgridspecを使ってみました、
  • 下をコードを実行することで、図のように、Figureを20x20に分割し、1-17行目に3Dグラフ描写用のAxesを、18, 19行目にそれぞれSlider用のAxesを配置できます。

# Figureの中に3Dグラフ、Slider用のAxesを追加
axcolor = 'gold'
gs = fig.add_gridspec(20, 20)
ax1 = fig.add_subplot(gs[:17,:], projection='3d')
ax_slider_z = fig.add_subplot(gs[18,:], facecolor=axcolor)
ax_slider_xy = fig.add_subplot(gs[19,:], facecolor=axcolor)

(5) Sliderの設定

次に、Sliderオブジェクトを作成します。以下のコードでは、

  • slider_zとslider_xyというSiderオブジェクトを作成し、それぞれ、上で作成したax_slider_z, ax_slider_xy上に描写
  • Sliderの動く範囲は-180〜180
  • Sliderの初期値としてz0(=0)を設定
  • Sliderの動く幅としてdelta(=10)を設定

を行っています。

# Sliderの設定
z0 = 0
xy0 = 0
delta = 10
slider_z = Slider(ax_slider_z, 'z-axis', -180, 180, valinit=z0, valstep=delta)
slider_xy = Slider(ax_slider_xy, 'xy-axis', -180, 180, valinit=xy0, valstep=delta)

(6) 3Dグラフの見る角度の初期値を設定後、3Dグラフを表示。

次に、3Dグラフの、z軸周り、xy平面周りの見える角度の初期値を設定します。この数値を変化させることで、3Dグラフの見る角度を動かすことができます。

# 3Dグラフの見る方向の初期値を設定
ax1.view_init(elev=z0, azim=xy0)

# 3Dグラフを表示
ax1.plot_surface(X, Y, Z)

(7) Sliderを動かした時に呼ばれるコールバック関数を作成

次に、Sliderを動かした時に呼ばれるコールバック関数を作成します。ここで、3Dグラフを見る角度を指定し、グラフを再描写することで動くグラフを作成することができます。

# Sliderを動かした時に呼ばれるコールバック関数
def view_change(val):

    sz = slider_z.val
    sxy = slider_xy.val
    ax1.view_init(elev=sxy, azim=sz)
    fig.canvas.draw_idle()    

(8) コールバック関数の設定

最後に、Sliderオブジェクトに先ほど作成したコールバック関数を設定します。

slider_z.on_changed(view_change)
slider_xy.on_changed(view_change)
plt.show()

以上で、JupyterLab上で簡単に動くグラフを作成することができます!
こちらにgistも置いておきましたので、よろしければお使いください。(上でご説明したものと同じですが。。。)