Positional Encoding (sinusoid) の実装に関するメモ


ポイント

  • Positional Encoding (sinusoid) を実装し、具体的な数値で確認。

レファレンス

1. Attention Is All You Need

数式


    
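$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
$$PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
     (参照論文より引用。ここで pos は位置、i は次元のペア番号、d_model はサンプルコードの n_units に対応する。)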

サンプルコード

def get_lookup_table(length, n_units):
  lt = np.array([[pos / np.power(10000, 2.0 * i / n_units) \
                  for i in range(int(n_units / 2))] \
                          for pos in range(length)])
  lt = np.repeat(lt, 2, axis = 1)

  lt[:, 0::2] = np.sin(lt[:, 0::2])
  lt[:, 1::2] = np.cos(lt[:, 1::2])

  return tf.convert_to_tensor(lt)

def positional_encoding(batch_size, length, n_units):
  """Return sinusoidal positional encodings for a batch of sequences.

  Every batch entry gets the encodings of positions 0 .. length - 1, so the
  result stacks ``batch_size`` copies of ``get_lookup_table(length, n_units)``.
  Shape: (batch_size, length, D), where D is that table's width.
  """
  positions = tf.range(length)                     # (length,) position indices
  positions = tf.expand_dims(positions, 0)         # (1, length)
  positions = tf.tile(positions, [batch_size, 1])  # (batch_size, length)

  table = get_lookup_table(length, n_units)
  # Gather one table row per position index.
  return tf.nn.embedding_lookup(table, positions)


# NOTE(review): this snippet never imports its dependencies; it assumes
# `import numpy as np` and `import tensorflow as tf` have already been run.

# Demo: a batch of 2 sequences, 10 positions each, 6-dimensional encodings.
batch_size = 2
length = 10
n_units = 6

pe = positional_encoding(batch_size, length, n_units)

if hasattr(tf, "Session"):
  # TensorFlow 1.x: the tensor is symbolic and must be evaluated in a session.
  with tf.Session() as sess:
    print(sess.run(pe))
else:
  # TensorFlow 2.x removed tf.Session; tensors are evaluated eagerly.
  print(pe.numpy())

結果

出力の形状は (batch_size, length, n_units) = (2, 10, 6) で、2つのバッチ要素は同じ値になる。

[[[ 0.          1.          0.          1.          0.          1.        ]
  [ 0.84147098  0.54030231  0.04639922  0.99892298  0.00215443  0.99999768]
  [ 0.90929743 -0.41614684  0.0926985   0.99569422  0.00430886  0.99999072]
  [ 0.14112001 -0.9899925   0.1387981   0.9903207   0.00646326  0.99997911]
  [-0.7568025  -0.65364362  0.18459872  0.98281398  0.00861763  0.99996287]
  [-0.95892427  0.28366219  0.23000171  0.97319022  0.01077197  0.99994198]
  [-0.2794155   0.96017029  0.27490927  0.96147017  0.01292625  0.99991645]
  [ 0.6569866   0.75390225  0.31922465  0.94767907  0.01508047  0.99988628]
  [ 0.98935825 -0.14550003  0.36285241  0.93184662  0.01723462  0.99985147]
  [ 0.41211849 -0.91113026  0.40569857  0.91400693  0.0193887   0.99981202]]

 [[ 0.          1.          0.          1.          0.          1.        ]
  [ 0.84147098  0.54030231  0.04639922  0.99892298  0.00215443  0.99999768]
  [ 0.90929743 -0.41614684  0.0926985   0.99569422  0.00430886  0.99999072]
  [ 0.14112001 -0.9899925   0.1387981   0.9903207   0.00646326  0.99997911]
  [-0.7568025  -0.65364362  0.18459872  0.98281398  0.00861763  0.99996287]
  [-0.95892427  0.28366219  0.23000171  0.97319022  0.01077197  0.99994198]
  [-0.2794155   0.96017029  0.27490927  0.96147017  0.01292625  0.99991645]
  [ 0.6569866   0.75390225  0.31922465  0.94767907  0.01508047  0.99988628]
  [ 0.98935825 -0.14550003  0.36285241  0.93184662  0.01723462  0.99985147]
  [ 0.41211849 -0.91113026  0.40569857  0.91400693  0.0193887   0.99981202]]]