tensorflow 2におけるカスタムレイヤのbuild()とcall()についての探究

15863 ワード

文書ディレクトリ

  • 0 0 x 00前セグメントコード
  • 0 x 01厨房牛1-_init__
  • 0 x 02厨房牛2--build()
  • 0 x 03厨房牛3--call()
  • 0 x 04最終出力
  • 0 x 05まとめ
  • 0 x 00先のコード


    質問:ネットワーク層をカスタマイズする際に、build()call()が何に使われているのかを明らかにしたいのですが、なぜ呼び出しに成功したのか、外部で定義する必要はありません.
    
    # coding=utf-8
    '''
    @ Summary: test call
    @ Update:  
    
    @ file:    test.py
    @ version: 1.0.0
    
    @ Author:  [email protected]
    @ Date:    2020/6/11  3:48
    '''
    from __future__ import absolute_import, division, print_function
    import tensorflow as tf
    tf.keras.backend.clear_session()
    import tensorflow.keras as keras
    import tensorflow.keras.layers as layers
    
    class MyLayer(layers.Layer):
       def __init__(self, unit=32):
           super(MyLayer, self).__init__()
           self.unit = unit
    
       def build(self, input_shape):
           self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
                                         initializer=keras.initializers.RandomNormal(),
                                         trainable=True)
           self.bias = self.add_weight(shape=(self.unit,),
                                       initializer=keras.initializers.Zeros(),
                                       trainable=True)
    
       def call(self, inputs):
           return tf.matmul(inputs, self.weight) + self.bias
    
    my_layer = MyLayer(3)
    x = tf.ones((3,5))
    out = my_layer(x)
    print(out)
    
    

    0 x 01厨房牛1-init


    クラスオブジェクトの定義
    
    my_layer = MyLayer(3)
    
    

    OK、ここには何の問題もありませんMyLayer()クラスの__init__()メソッドのみが呼び出され、self.units = 3という変数が得られた.
    クラス内のbuild()およびcall()メソッドは、ここでは呼び出されていません.
    
       def __init__(self, unit=32):
    
           #  , , 
           super(MyLayer, self).__init__() 
           self.unit = unit
    
    

    0 x 02厨房牛2–build()


    入力オブジェクトを初期化します.
    
    x = tf.ones((3,5))
    
    

    この一歩も何の問題もなく続けていく
    
    out = my_layer(x)
    
    

    ここ、問題が来ました.
    最初の質問に戻ります.なぜ外部呼び出しなしでbuild()call()などの関数を実行できるのですか.
    回答:Layer()クラスには__call__()の魔法の方法があります(上記の2つの関数はすでにtfによってこの関数の下に統合されています)、自動的に呼び出されますので、外部呼び出しを使わないで、具体的にどのように呼び出すか、ソースコードを読んでください
    次にmy_layerに入力し、xに入力します.build()メソッドを呼び出します.
    
       def build(self, input_shape):
           self.weight = self.add_weight(shape=(input_shape[-1], self.unit),
                                         initializer=keras.initializers.RandomNormal(),
                                         trainable=True)
           self.bias = self.add_weight(shape=(self.unit,),
                                       initializer=keras.initializers.Zeros(),
                                       trainable=True)
    
    

    2つの訓練可能な値を初期化して、それぞれ重みとバイアスで、ok、この部分の問題は解決しました
    ついでに別の問題を解決します:どうしてbuild()の方法がありますか
    回答:build()ネットの重みの次元をカスタマイズすることができて、入力によって重みの次元を指定することができて、重みが固定するならば、build()の方法を使うことを避けることができます
    もう一つ注意すべき点は、self.built = Trueです.
    このパラメータは、build()の運転開始時にFalseであり、build()メソッドが先に呼び出されることを保証するために、call()メソッドが呼び出される
    終了時にTrueに自動的に割り当てられ、build()メソッドが1回のみ呼び出されることを保証します.
    
    class MyLayer(layers.Layer):
    
        def __init__(self, input_dim=32, unit=32):
    
            super(MyLayer, self).__init__()
    
            self.weight = self.add_weight(shape=(input_dim, unit),
    
                                         initializer=keras.initializers.RandomNormal(),
    
                                         trainable=True)
    
            self.bias = self.add_weight(shape=(unit,),
    
                                       initializer=keras.initializers.Zeros(),
    
                                       trainable=True)
    
        
    
        def call(self, inputs):
    
            return tf.matmul(inputs, self.weight) + self.bias
    
    

    0 x 03厨房牛3–call()

    build()メソッドを呼び出した後、初期化の重みとバイアス値を取得し、次に順方向伝播を行い、公式サイトでは論理機能関数を実現すると言っていますが、私は前者と言うのが好きで、もっと理解しやすいです.
    
       def call(self, inputs):
           return tf.matmul(inputs, self.weight) + self.bias
    
    

    アクティブ化関数計算を含まないレイヤの出力値を返します.

    0 x 04最終出力

    
    print(out)
    
    

    0 x 05まとめ


    Layerのサブクラスは一般的に以下のように実現される.
  • init():super()は、入力に関係のないすべての変数
  • を初期化します.
  • build():レイヤ内のパラメータと変数を初期化するための
  • call():順方向伝播を定義する
  • 1回目の訓練はまずModel(x)を計算し、それからModel(x).build(input)を計算し、最後にModel(x).call(input)を計算し、2回目以降は中間ステップをスキップした.