Kerasでmodel.saveしようとしてmust override `get_config`エラーが出たときの対処


はじめに

本記事ではKerasでmodel.save(あるいはmodel.to_json)しようとしてXX has arguments in `__init__` and therefore must override `get_config`.が出たときの対処法を紹介したいと思います。

背景と原因

KerasではDense層やConv層など事前定義されたレイヤーが多数存在し、これらを組み合わせることで基本的なモデルを設計します。
しかし、より発展的には、カスタムレイヤーを自分で実装しモデルに追加することになります。
例えば最新の論文で発表された仕組みを利用したいには場合、Kerasの事前定義レイヤーには存在せず、Githubから引用したり、自分で実装したりする必要があります。
(カスタムレイヤーの実装について興味のある場合は、こちらの公式Exampleをご確認ください。)
あるいは、初学者においては、kaggleのkernelなどで公開されているスクリプトをフォークした際に、知らず知らずのうちにカスタムレイヤーを含んだモデルを利用しているかもしれません。(私自身もそのようにして今回のエラーに直面しました。)

さて、XX(カスタムレイヤー名) has arguments in `__init__` and therefore must override `get_config`というエラーは、このカスタムレイヤーを含んだモデルに対して正しく対処できていない際に、Kerasから「そんなレイヤー知らないよ」と怒られて生じるものなのです。

解決方法

カスタムレイヤーのクラス内でget_config()をオーバーライドすることで解決できます。
より具体的には、カスタムレイヤーのクラスの__init__の引数を辞書にして、親クラスのconfigに追加して返すようなget_config()を定義します。
これが意味するところは、__init__の引数はこのカスタムレイヤーの設計書のようなものですから、勝手に作ったカスタムレイヤーの仕組みをKerasに明示的に教えてあげていることに相当します。

ちなみに、このようにして保存されたモデルはロードする際にもカスタムレイヤーをcustom_objectsアーギュメントで明示的に示す必要があります。
方法は非常に簡単で、以下のように行います。

load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})

具体例

KaggleのこちらのPublic Kernelを例に説明いたします。
[GLRec] ResNet50 ArcFace (TF2.2)

このスクリプトのうち、実際にモデルの定義は以下で行われます。
backboneとなるモデルはResNet50でKerasに事前定義されています。(weightも今回のようにローカルに保存したものを使用するだけでなく、Kerasのパッケージで取得できます。)
また、pooling層やdropout層も事前定義されたものです。

このなかでmargin層だけは独自にインスタンス化していることがわかります。これがこのモデルのカスタム層です。

create_model.py

def create_model(input_shape,
                 n_classes,
                 dense_units=512,
                 dropout_rate=0.0,
                 scale=30,
                 margin=0.3):

    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        input_shape=input_shape,
        weights=('../input/imagenet-weights/' +
                 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
    )

    pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
    dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
    dense = tf.keras.layers.Dense(dense_units, name='head/dense')

    margin = ArcMarginProduct(
        n_classes=n_classes,
        s=scale,
        m=margin,
        name='head/arc_margin',
        dtype='float32')

    softmax = tf.keras.layers.Softmax(dtype='float32')

    image = tf.keras.layers.Input(input_shape, name='input/image')
    label = tf.keras.layers.Input((), name='input/label')

    x = backbone(image)
    x = pooling(x)
    x = dropout(x)
    x = dense(x)
    x = margin([x, label])
    x = softmax(x)
    return tf.keras.Model(
        inputs=[image, label], outputs=x)

margin層のクラスであるArcMarginProductを確認します。
すると、tf.keras.layers.Layerを継承したカスタムレイヤーであることがわかります。
(ちなみに、実装している技術はArcFaceといいます。)

このように独自定義されたカスタムレイヤー内で、get_config()を正しくオーバーライドしていないとき、model.saveをすると冒頭のエラーに直面するのでした。

今回のKernelではget_config()がクラス内で定義されていないので、そのままsaveしようとするとエラーがでます。

custom_layer.py
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

そこで、以下のような変更を加える必要があります。

具体的には、get_config()をオーバーライドし、__init__の引数と親クラスのconfigを返しています。

new_custom_layer.py

class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m


### Start 追加されたコード
    def get_config(self):
        config = {
            "n_classes" : self.n_classes,
            "s" : self.s,
            "m" : self.m,
            "easy_margin" : self.easy_margin,
            "ls_eps" : self.ls_eps
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

###  End       

    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

そしてモデルのロードは以下のように行う必要があります。

load_model.py

loaded_model =keras.models.load_model("path_to_model", custom_objects = {"ArcMarginProduct": ArcMarginProduct})

参考

Kerasでカスタムレイヤーを作成する方法
kerasでカスタムレイヤーのシリアライズを行う
NotImplementedError: Layers with arguments in __init__ must override get_config