tf.batch_gatherの使い方メモ


tf.gatherと仲間たち

まず、gatheの処理をしてくれる仲間たちを紹介する。
- tf.gather 指定したaxisに沿ってスライスしたものをgatherする。
- tf.gather_nd 自由度が高く、任意の要素からgatherされる。
- tf.batch_gather batchに沿って処理される。indicesでスライスされる。
今回はバッチごとにgather処理をしてくれるtf.batch_gatherを使ってみた。

tf.batch_gatherの使い方

tf.batch_gatherの解説

tf.batch_gather(
    params,
    indices,
    name=None
)

prams(Tensor): [A1, ..., AN-1, AN, B1, ..., BM]
indices(Tensor): [A1, ..., AN-1, C]
result(Tensor): [A1, ..., AN-1, C, B1, ..., BM]

  • ANでスライスされ、A1〜AN-1の構造はそのまま維持される。
  • CにANからgatherされる値が入る。
  • Cの各値はANの大きさ以下でなければいけない。

詳しくはTensorFlowの公式ドキュメントを参考にしてもらいたい。

今回の処理

点群屋さんなので、点群を例に説明します。
8点から成る3次元点群を用意した。ここから4点を選びgatherするような処理を行う。
Bはバッチ、Nは点群の点、Cはチャンネル(3次元点群なのでxyz)

サンプルコード

batch_gather.py
import tensorflow as tf

#元の点群(2batch x 8点 x 3次元)
param = tf.constant([[[0.,0.,9.],[0.,1.,9.], [0.,2.,9.], [0.,3.,9.], [0.,4.,9.], [0.,5.,9.], [0.,6.,9.], [0.,7.,9.]],
                     [[1.,0.,9.],[1.,1.,9.], [1.,2.,9.], [1.,3.,9.], [1.,4.,9.], [1.,5.,9.], [1.,6.,9.], [1.,7.,9.]]]) # B x N x C(2 x 8 x 3)
print("param shape:     ",param.shape)

#gather元を選ぶテンソル(2batch x 4点)
indices = tf.constant([[1,0,0,4],
                       [1,3,4,6]])    # B x N
print("indices shape:   ",indices.shape)

#gatherする
result = tf.batch_gather(param, indices)
print("result shape:    ",result.shape)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print("param\n",sess.run(param))     #入力の確認
print("indices\n",sess.run(indices)) #indiciesの確認
print("result\n",sess.run(result))   #gatherの結果

実行結果

param shape:      (2, 8, 3)
indices shape:    (2, 4)
result shape:     (2, 4, 3)

param
 [[[0. 0. 9.]
  [0. 1. 9.]
  [0. 2. 9.]
  [0. 3. 9.]
  [0. 4. 9.]
  [0. 5. 9.]
  [0. 6. 9.]
  [0. 7. 9.]]

 [[1. 0. 9.]
  [1. 1. 9.]
  [1. 2. 9.]
  [1. 3. 9.]
  [1. 4. 9.]
  [1. 5. 9.]
  [1. 6. 9.]
  [1. 7. 9.]]]

indices
 [[1 0 0 4]
 [1 3 4 6]]

result
 [[[0. 1. 9.]
  [0. 0. 9.]
  [0. 0. 9.]
  [0. 4. 9.]]

 [[1. 1. 9.]
  [1. 3. 9.]
  [1. 4. 9.]
  [1. 6. 9.]]]