CNNダウンサンプリング時の位置ずれ問題と簡単な解消方法


はじめに

典型的なCNNでは何らかの形でダウンサンプリングが行われます。
従来は2x2プーリングをストライド2x2で使われていましたが、最近は3x3畳み込みや3x3プーリングをストライド2x2で使われることが多いようです。

このようなダウンサンプリングをする時にちょっとした罠になるのが座標計算です。
各種のディープラーニングフレームワークでは、Convolutionのパディング方式として"SAME"モードが指定できると思いますが、ナイーブにこれを使うと、ダウンサンプリング層のたびに受容領域の位置が出力に対してずれていくという問題があります。
通常のディープなモデルを使っているうちは影響は少ないですが、高速化のためにネットワークを小型化したりするとこの問題が表出してきます。

問題の再現

下記のようなモデルがあったとします:

h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="same")(x)
h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)
h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)
y = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)

このとき、受容領域は次のようになります。
出力が左上と右下にあるときを比較すると、出力座標に対して受容領域の位置がずれていることがわかります。

出力セル:

対応する受容領域:

解決策

ダウンサンプリングの回数を予め考慮してはじめにパディングをしてしまうことで、簡単に解消できます。
なお入力画像サイズが偶数だと1ピクセルのずれは出ますので、その点は注意してください。

h = ZeroPadding2D(((7, 8), (7, 8)))(x)
h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="valid")(h)
h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="valid")(h)
h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="valid")(h)
y = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="valid")(h)

同様に可視化します。治っていますね
出力セル:

対応する受容領域:

検証用コード

検証に使ったコードを記載しておきます:

cnn_receptive_field.py
from keras.layers import *
from keras import backend as K
import cv2

height, width = 128, 128
layers = 4
cell = 2**layers

# 受容領域を可視化するためのモデル
m = Input(batch_shape=(1, height//cell, width//cell, 1))
x = K.ones((1, height, width, 1))

# "same" mode
# h = x
# for i in range(layers):
#     h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="same")(h)

# padding
h = ZeroPadding2D(((cell//2-1, cell//2), (cell//2-1, cell//2)))(x)
for i in range(layers):
    h = Conv2D(1, kernel_size=(3, 3), strides=(2, 2), padding="valid")(h)

d = Lambda(lambda x: x*m)(h)
g = K.gradients(d, x)

func = K.function([m], g)


# 可視化実施
count = 0
for i in range(height//cell):
    for j in range(width//cell):
        mask = np.zeros((1, height//cell, width//cell, 1))
        mask[0, i, j, 0] = 1
        grad = func([mask])[0]

        receptive_field = np.zeros_like(grad)
        receptive_field[grad != 0] = 1

        cv2.imshow("mask", mask[0])
        cv2.imshow("receptive_field", receptive_field[0])
        cv2.waitKey(0)

        # cv2.imwrite(f"img/mask{count:04d}.png", (mask[0]*255).astype(np.uint8))
        # cv2.imwrite(f"img/field{count:04d}.png",
        #             (receptive_field[0]*255).astype(np.uint8))
        # count += 1