tensorflowにおけるlayernormalizationについて

9018 ワード

循環ニューラルネットワークはLayerNormalizationにあまり適していないことが知られている.
LayerNormalizationについて:https://blog.csdn.net/liuxiao214/article/details/81037416
LayerNormalizationにおける同層ニューロン入力は同じ平均値と分散を有し,異なる入力サンプルは異なる平均値と分散を有する.すなわち、入力がピクチャであると仮定し、ある層のshapeが(m,h,w,c)であり、mがロットサイズであり、hが高く、wが幅であり、cがチャネル数であると仮定する.
式(x-mean)/stdは、LayerNormalizationによって適用される.xは入力(m,h,w,c)であり,このmeanのshapeは(m,),stdのshapeは(m,),これにより各サンプルに異なる平均値と分散が保証され,同時に正規化が完了する.
ループニューラルネットワークでは,(m,t,feature)と入力し,tが時間ステップを表すと仮定するとmeanのshapeは何であるか.stdのmeanは何ですか?
論文によればmeanのshapeは(m,t),stdのshapeは(m,t)であり,各サンプルの各時間ステップに独自の平均値と標準差があることが分かる.
実験検証:
1 tensorflow 1.xバージョン
観察tensorflow 1.14ドキュメント.
定義
tf.contrib.layers.layer_norm(
    inputs,
    center=True,
    scale=True,
    activation_fn=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    begin_norm_axis=1,
    begin_params_axis=-1,
    scope=None
)

と解釈する
Given a tensor  inputs  of rank  R( R) , moments are calculated and normalization is performed over axes  begin_norm_axis ... R - 1 . Scaling and centering, if requested, is performed over axes  begin_params_axis .. R - 1 .
By default,  begin_norm_axis = 1  and  begin_params_axis = -1 , meaning that normalization is performed over all but the first axis (the  HWC  if  inputs  is  NHWC ), while the  beta  and  gamma  trainable parameters are calculated for the rightmost axis (the  C  if  inputs  is  NHWC ). Scaling and recentering is performed via broadcast of the  beta  and  gamma  parameters with the normalized tensor.
The shapes of  beta  and  gamma  are  inputs.shape[begin_params_axis:] , and this part of the inputs' shape must be fully defined.
この解釈を見て、デフォルトパラメータでは、ピクチャについて、入力shapeが(m,h,w,c)であると仮定すると、求めた平均値shapeは(m,)、begin_norm_axisのデフォルトは1です.しかし、循環ニューラルネットワークにはbegin_を付与すべきである.norm_axis=−1であり,計算したmeanのshapeは(m,t)であった.
実験検証(tf 1.14):
import tensorflow as tf

x1 = tf.convert_to_tensor(
    [[[18.369314, 2.6570225, 20.402943],
      [10.403599, 2.7813416, 20.794857]],
     [[19.0327, 2.6398268, 6.3894367],
      [3.921237, 10.761424, 2.7887821]],
     [[11.466338, 20.210938, 8.242946],
      [22.77081, 11.555874, 11.183836]],
     [[8.976935, 10.204252, 11.20231],
      [-7.356888, 6.2725096, 1.1952505]]])
mean_x = tf.reduce_mean(x1, axis=-1)
print(mean_x.shape)  # (4, 2)
mean_x = tf.expand_dims(mean_x, -1)

std_x = tf.math.reduce_std(x1, axis=-1)
print(std_x.shape)  # (4, 2)
std_x = tf.expand_dims(std_x, -1)

#     
la_no1 = (x1-mean_x)/std_x

x = tf.placeholder(tf.float32, shape=[4, 2, 3])
la_no = tf.contrib.layers.layer_norm(
      inputs=x, begin_norm_axis=-1, begin_params_axis=-1)
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    x1 = sess1.run(x1)

    #     
    print(sess1.run(la_no1))
    '''
    [[[ 0.5749929  -1.4064412   0.83144826]
    [-0.1250188  -1.1574404   1.2824593 ]]

    [[ 1.3801126  -0.9573896  -0.422723  ]
    [-0.5402143   1.4019758  -0.86176145]]

    [[-0.36398557  1.3654773  -1.0014919 ]
    [ 1.4136491  -0.6722269  -0.74142253]]

    [[-1.2645671   0.08396867  1.1806016 ]
    [-1.3146634   1.108713    0.20595042]]]
    '''

    # tensorflow  
    print(sess1.run(la_no, feed_dict={x: x1}))
    '''
    [[[ 0.574993   -1.4064413   0.8314482 ]
    [-0.12501884 -1.1574404   1.2824591 ]]

    [[ 1.3801126  -0.9573896  -0.422723  ]
    [-0.5402143   1.4019756  -0.86176145]]

    [[-0.36398554  1.3654773  -1.0014919 ]
    [ 1.4136491  -0.67222667 -0.7414224 ]]

    [[-1.2645674   0.08396816  1.1806011 ]
    [-1.3146634   1.108713    0.20595042]]]
    '''


両者の結果は一致することが分かった.すなわちRNN系列入力をx(m,t,feature)とし,式(x−mean)/stdを適用するとmeanのshapeは(m,t),stdのshapeは(m,t)とする.
2 tensorflow 2.xバージョン
tensorflowを2にアップグレードx tfを発見した.contribは消えた
2.1.0ドキュメントの表示
tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=0.001, center=True, scale=True, beta_initializer='zeros',
    gamma_initializer='ones', beta_regularizer=None, gamma_regularizer=None,
    beta_constraint=None, gamma_constraint=None, trainable=True, name=None, **kwargs
)

多くのパラメータと1.xが違います.
公式解釈
Normalize the activations of the previous layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
ポイントはパラメータaxisaxis : Integer or List/Tuple. The axis that should be normalized (typically the features axis).
このaxisのデフォルトは-1であり、Batchnormalizationと混同されやすいという錯覚を与え、ここでは(m,t,feature)と入力すると仮定する.
ではaxis=−1では,求めたmeanのshapeは(m,t)であり,各サンプルの各時間ステップに対して平均値を求める.
実験検証(tf 2.1.0)
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

tf.random.set_seed(1234)

x1 = tf.random.normal((4, 2, 3), mean=10, stddev=10.0)
tf.print(x1)
'''
[[[18.3693142 2.65702248 20.4029427]
  [10.4035988 2.78134155 20.794857]]

 [[19.0327 2.63982677 6.38943672]
  [3.92123699 10.7614241 2.78878212]]

 [[11.4663382 20.2109375 8.24294567]
  [22.7708092 11.5558739 11.183836]]

 [[8.97693539 10.2042522 11.2023096]
  [-7.35688782 6.27250957 1.19525051]]]
'''

# tensorflow  
x2 = tf.keras.layers.LayerNormalization()(x1)
tf.print(x2)
'''
[[[0.574988365 -1.40643 0.8314417]
  [-0.125017628 -1.15742981 1.28244758]]

 [[1.38009858 -0.957379758 -0.422718644]
  [-0.540192485 1.40191901 -0.861726582]]

 [[-0.363978326 1.36545086 -1.00147223]
  [1.41362476 -0.672215044 -0.74140954]]

 [[-1.26380491 0.0839173347 1.17988873]
  [-1.31464267 1.10869551 0.205947176]]]

'''

x_mean = tf.reduce_mean(x1, axis=-1)
x_mean = tf.expand_dims(x_mean, -1)

x_std = tf.math.reduce_std(x1, axis=-1)
x_std = tf.expand_dims(x_std, -1)

#     
x3 = (x1-x_mean)/x_std
tf.print(x3)
'''
[[[0.574992955 -1.40644133 0.831448257]
  [-0.125018805 -1.15744042 1.28245926]]

 [[1.38011265 -0.957389593 -0.422723]
  [-0.5402143 1.40197563 -0.861761391]]

 [[-0.363985568 1.36547732 -1.0014919]
  [1.41364908 -0.672226906 -0.741422534]]

 [[-1.2645669 0.0839686543 1.18060148]
  [-1.31466341 1.10871303 0.205950424]]]
'''

結果がほぼ一致する微細差は,epsilon(0で割ることを防止する)などの導入に起因することが分かった.
ただし、ピクチャではデフォルトのaxisは使用できませんが、axis=(1,2,3)のように計算されたmeanのshapeを(m,)に設定する必要があります.
実験検証:
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

tf.random.set_seed(1234)

x1 = tf.random.normal((4, 2, 2, 3), mean=10, stddev=10.0)

x2 = tf.keras.layers.LayerNormalization(axis=(1, 2, 3))(x1)
# tensorflow  
tf.print(x2)
'''
[[[[1.13592184 -1.01682484 1.41455]
   [0.0445363373 -0.999791801 1.46824622]]

  [[1.22681248 -1.01918077 -0.505445421]
   [-0.84361434 0.0935621038 -0.998772383]]]


 [[[0.241638556 1.41170228 -0.189664766]
   [1.75422382 0.253618807 0.203838587]]

  [[-0.0914538875 0.0727662146 0.206310436]
   [-2.27698731 -0.453317314 -1.13267565]]]


 [[[0.660608232 -1.50688756 -1.25433147]
   [-0.108726673 0.251018792 1.11019969]]

  [[1.55131137 -0.91692245 1.37210977]
   [0.350336075 -0.651512 -0.857205331]]]


 [[[-1.13443768 0.891957879 -1.49989474]
   [0.853844702 2.05934501 -0.13168712]]

  [[-0.669010222 -1.08240855 0.9768731]
   [-0.31119889 -0.043616 0.0902313]]]]
'''

x_mean = tf.reduce_mean(x1, axis=(1, 2, 3))
# print(x_mean.shape)  # (10,)

x_mean = tf.reshape(x_mean, (-1, 1, 1, 1))
# print(x_mean.shape)  # (10, 1, 1, 1)


x_std = tf.math.reduce_std(x1, axis=(1, 2, 3))
x_std = tf.reshape(x_std, (-1, 1, 1, 1))

x3 = (x1-x_mean)/x_std
#     
tf.print(x3)  #            
'''
[[[[1.13593256 -1.01683426 1.4145633]
   [0.0445368551 -0.99980104 1.46826017]]

  [[1.22682405 -1.01919019 -0.50545007]
   [-0.843622148 0.0935630798 -0.998781621]]]


 [[[0.241640702 1.41171491 -0.18966648]
   [1.75423956 0.253621072 0.203840405]]

  [[-0.091454722 0.0727668479 0.206312269]
   [-2.27700782 -0.453321397 -1.1326859]]]


 [[[0.660612881 -1.50689745 -1.25433969]
   [-0.108727202 0.25102067 1.11020732]]

  [[1.55132198 -0.916928411 1.37211931]
   [0.350338638 -0.651516199 -0.857210875]]]


 [[[-1.13444376 0.891962826 -1.49990284]
   [0.853849471 2.05935621 -0.131687731]]

  [[-0.669013798 -1.08241439 0.976878583]
   [-0.31120047 -0.0436161309 0.0902319]]]]
'''

ほぼ一致していることがわかります.