tf-keras-visを使ってGradCAM、GradCAM++、ScoreCAM、Faster-ScoreCAM


はじめに

ディープラーニングでの予測結果における特徴部位可視化でよく使われるGradCAMなどを出力してくれるツールtf-keras-visですが、すごく簡単に出力できるうえGradCAMだけでなくGradCAM++、ScoreCAM、Faster-ScoreCAM、Vanilla Saliency、SmoothGradといろんな種類の可視化を行うことができます
そんなtf-keras-visですが、公式のサンプルだとloss関数のクラスインデックスを書き換えて利用する方式なので、クラスインデックス・予測画像・認識モデルの3つの引数で取得するサンプルを作りました

tf-keras-vis
https://github.com/keisen/tf-keras-vis

検証環境

この記事の内容は、以下の環境で検証しました。
Python 3.6.9
TensorFlow 2.4.0-rc0
tf-keras-vis 0.5.3

環境準備

すでに必要なモジュールが入っているならとばしてください

pip install --upgrade tf-keras-vis matplotlib

モジュールの読み込み

import os
import glob
import numpy as np 
import matplotlib.pyplot as plt
import tensorflow as tf

モデル用モジュール

from tensorflow.keras.applications.vgg16 import VGG16 as Model
from tensorflow.keras.applications.vgg16 import preprocess_input

モデル読み込み

model = VGG16(include_top=False, weights='imagenet')
#model = tf.keras.models.load_model('mymodel.h5')
model.summary()

tf-keras-vis用モジュール

from tf_keras_vis.saliency import Saliency
from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.gradcam import GradcamPlusPlus
from tf_keras_vis.scorecam import ScoreCAM
from tf_keras_vis.utils import normalize
from matplotlib import cm

特徴可視化マップ取得関数

SmoothGrad

def GetSmoothGrad(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  saliency = Saliency(model,model_modifier=model_modifier,clone=False)
  cam = saliency(loss, img, smooth_samples=20, smooth_noise=0.20)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

GradCAM

def GetGradCAM(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  gradcam = Gradcam(model,model_modifier=model_modifier,clone=False)
  cam = gradcam(loss, img, penultimate_layer=-1)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

GradCAM++

def GetGradCAMPlusPlus(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  gradcam = GradcamPlusPlus(model,model_modifier=model_modifier,clone=False)
  cam = gradcam(loss, img, penultimate_layer=-1)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

ScoreCAM

def GetScoreCAM(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  scorecam = ScoreCAM(model,model_modifier=model_modifier,clone=False)
  cam = scorecam(loss, img, penultimate_layer=-1)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

Faster ScoreCAM

def GetFasterScoreCAM(cls_index, img, model):
  def loss(output):
    return (output[0][cls_index])
  def model_modifier(m):
    m.layers[-1].activation = tf.keras.activations.linear
    return m
  scorecam = ScoreCAM(model,model_modifier=model_modifier,clone=False)
  cam = scorecam(loss, img, penultimate_layer=-1, max_N=10)
  cam = normalize(cam)
  heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
  return heatmap

取得テスト

from tensorflow.keras.preprocessing.image import load_img

IMAGE_PATH = 'Image.JPG'
CAT_CLASS_INDEX = 0

# Load image
img = load_img(IMAGE_PATH, target_size=(224, 224))
# Preparing input data
X = preprocess_input(np.array(img))

#Get SmoothGrad
heatmap = GetSmoothGrad(CAT_CLASS_INDEX, X, model)
plt.figure(figsize=(20,20))
plt.subplot(1, 3, 1)
plt.title('SmoothGrad')
plt.imshow(heatmap)

#Get GradCAM++
heatmap = GetGradCAMPlusPlus(CAT_CLASS_INDEX, X, model)
plt.subplot(1, 3, 2)
plt.title('GradCAMPlusPlus')
plt.imshow(heatmap)

#Get FasterScoreCAM
heatmap = GetFasterScoreCAM(CAT_CLASS_INDEX, X, model)
plt.subplot(1, 3, 3)
plt.title('FasterScoreCAM')
plt.imshow(heatmap)
plt.show()

注意点

CAT_CLASS_INDEXは分類のクラスidではなくindexであることにご注意ください
また各関数に投げる画像はサンプルにあるように
img = load_img(IMAGE_PATH, target_size=(224, 224))
X = preprocess_input(np.array(img))
学習時と同じ大きさと状態の画像を投げてください