matplotlibのcolormapを使って二次元データを高速に画像化


まとめると

二次元データを画像化して保存したいとき、matplotlib.pyplot.imshow() は遅いので、matplotlibのColormapを使って直接保存できます。

問題設定

ある二次元データが(たくさん)あって、画像化したいとします。

import numpy as np
from matplotlib import pyplot as plt
data = np.arange(100).reshape(10,10)
plt.imshow(data)
plt.savefig("test.png")

しかしご存知の通りmatplotlibは遅いですね。遅いのはいちいちグラフィックレンダラーを呼んでレンダリングしてその結果を保存しているからです。
しかしやりたいことは、dataをなんらかの変換でRGBに押し込んで保存するだけです。

カラーマップを使えたら嬉しいよね、という動機を冗長に説明してしまったので冗長な説明が見たい人はクリック!

画像データとして押し込むだけなら、次のようにすればいいです。

import PIL.Image
def quantize(arr):
    arr = np.floor((arr - arr.min())  * 255 / (arr.max() - arr.min())).astype(np.uint8)

img = PIL.Image.fromarray(quantize(data))
img = img.resize([400,400], resample=PIL.Image.NEAREST) #拡大しているだけ
img.save("test.png")

カラー画像にするなら

img = np.zeros(data.shape+(3,), dtype=np.uint8) # キャンバスの用意
img[:,:,0] = 255 # 適当に1
img[:,:,1] =  255 - quantize(data) # 適当に2
img[:,:,2] = quantize(data) # 適当に3

img = PIL.Image.fromarray(img)
img = img.resize([400,400], resample=PIL.Image.NEAREST) #拡大しているだけ
img.save("test.png")

として「適当」って書いたところに適当な関数を割り当てればいいんですが、その適当な関数を自分で用意するのは厳しいものがある。というか初めに示したimshow()の通りの絵を得るにはどうしたらいいんでしょうね、ということになります。

Colormapの使い方

ご存知の通り(?)matplotlibにはカラーマップという概念があります。
値から色(RGB値)への写像です。plt.imshow(data, cmap=hoge)などと指定することができますよね。

カラーマップのバリエーションについてはblog:matplotlibのcmap(colormap)パラメータの一覧。など参照。

matplotlib内で、カラーマップはmatplotlib.colors.Colormapオブジェクトとして表現されています1

Colormapオブジェクトは、2d-arrayを受け取ってRGBA値の入った3d (2d*4ch) arrayを返す関数(に毛が生えたもの)として扱うことができます。
matplotlib.pyplot.get_cmap()にカラーマップの名前を指定することでColormapが返ってきます。つまり、

cmap = matplotlib.pyplot.get_cmap(NAME_OF_COLORMAP)
image = Colormap(raw_array)

という感じで、画像化が出来上がりです。
この画像をレンダラに表示させるのがplt.imshow()なので、画像が欲しいだけなら

PIL.Image.fromarray(image).save("filename.png")

の方がずっと早いということになりますね。例を示します。

import matplotlib.pyplot as plt

def minmax(arr):
    arr = (arr - arr.min()) / (arr.max() - arr.min())
    return arr

cmap = plt.get_cmap("viridis")
data = minmax(data)
img = cmap(data, bytes=True)
img = PIL.Image.fromarray(img)
img = img.resize([220,220], resample=PIL.Image.NEAREST) #拡大しているだけ
img.save("test.png")

  • 注意点1.

    • viridisというのはmatplotlibのデフォルトのカラーマップの名前です。
  • 注意点2.

    • 入力としては[0.,1.](float)か[0, 255](uint)の範囲のデータが期待されているので、先にminmax正規化によって値の範囲を調整しています。
  • 注意点3.

    • bytes=Trueを指定することで[0, 255]のデータが返ってきます。デフォルトでは[0.,1.]のデータが返ってきます。PIL.ImageではRGBAはでしか[0, 255]扱えないのでちょっと注意します。

以上で、plt.imshow()と同等のことがずいぶん速く実現できます。

参考: 自作colormapの作り方や既存カラーマップの中身を分析するやり方について以下が詳しいです。


  1. 実際にはColormapオブジェクトを継承したLinearSegmentedColormapListedColormapとして表現されています。