tf.keras.layers.Layerカスタムレイヤ

5243 ワード

前編のclassをさらに理解するためにgithubを以下の例に挙げた.
import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(int(input_shape[1]), self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        print("build: shape of input: ", input_shape)
        print("build: shape of kernel: ", self.kernel)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        print("call: dot of x & kernel: ", tf.keras.backend.dot(x, self.kernel))
        return tf.keras.backend.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

kerasではLambdaで独自のレイヤを定義できるほか,継承する方法で定義することもできる.後者は格がずっと高いように見えます.
上記のdense層が正しいかどうかはともかく,クラスの継承としての学習のみである.
x=tf.keras.Input(12,dtype=tf.float32)
x_out=MyLayer(16)(x)

build: shape of input:  (?, 12)
build: shape of kernel:  
call: dot of x & kernel:  Tensor("my_layer/MatMul:0", shape=(?, 16), dtype=float32)

この層の使い方は層に対して、入力したのは層です!!!tensorではありません.エラーの例は次のとおりです.
>>> y=tf.constant([1,2,3])
>>> y_out=MyLayer(16)(y)
Traceback (most recent call last):
  File "", line 1, in 
    y_out=MyLayer(16)(y)
  File "D:\python\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 824, in __call__
    self._maybe_build(inputs)
  File "D:\python\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 2146, in _maybe_build
    self.build(input_shapes)
  File "D:/python/pycode/tf_keras_layers_Layer_.py", line 13, in build
    shape=(int(input_shape[1]), self.output_dim),
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py", line 870, in __getitem__
    return self._dims[key]
IndexError: list index out of range

このクラスは入力層tensorのinput_を算出できるshape,callは論理層であり,この方法を明示的に呼び出す必要はない.
ここでcallのキーワードパラメータはinputsである必要があります.これは正規のように見えますが、kerasの他のクラス層、例えばDenseを連想することができます.
 |  build(self, input_shape)
 |      Creates the variables of the layer (optional, for subclass implementers).
 |      
 |      This is a method that implementers of subclasses of `Layer` or `Model`
 |      can override if they need a state-creation step in-between
 |      layer instantiation and layer call.
 |      
 |      This is typically used to create the weights of `Layer` subclasses.
 |      
 |      Arguments:
 |        input_shape: Instance of `TensorShape`, or list of instances of
 |          `TensorShape` if the layer expects a list of inputs
 |          (one instance per input).
 |  
 |  call(self, inputs)
 |      This is where the layer's logic lives.
 |      
 |      Arguments:
 |          inputs: Input tensor, or list/tuple of input tensors.
 |          **kwargs: Additional keyword arguments.
 |      
 |      Returns:
 |          A tensor or list/tuple of tensors.

buildはレイヤを作成する変数、add_Weightとかはこの層で定義されているに違いない.selfなどのselfを追加する必要がある.kernel,self.bias
callはこれらの変数を操作し、return結果はクラス層の最終結果である.
前回のクラスを考えると、入力したのはレイヤーではありません!!!しかし、同じエラーです.次のようにします.
x=tf.keras.Input(12,dtype=tf.float32)
inputs=list(map(Func(),x))

Traceback (most recent call last):
  File "", line 1, in 
    list(map(Func(),x))
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 547, in __iter__
    self._disallow_iteration()
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 543, in _disallow_iteration
    self._disallow_in_graph_mode("iterating over `tf.Tensor`")
  File "D:\python\lib\site-packages\tensorflow_core\python\framework\ops.py", line 523, in _disallow_in_graph_mode
    " this function with @tf.function.".format(task))
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.


エラー表示tf Tensorへの反復は図中でサポートされていない!!!
しかし、このクラス層は使えますし、Tensorと入力してもいいです.上のほうは顔を殴って、具体的な状況を具体的に分析しましょう.
>>> Func()(x)


>>> y=tf.constant([1.,2,3])
>>> y

>>> Func()(y)

恒等の役割でありresnetにおける恒等マッピングに類似している以上,このクラス層を除去することができる.
 
For Video Recommendation in Deep learning QQ Group 277356808
For Speech,Image, Video in deep learning QQ Group 868373192
I'm here waiting for you.