Pytorchはtfを実現する.gather()

976 ワード

1.tfを実現する.gather
pytorchでは、tfを実現する.gatherは簡単でselectを使うだけです.select(dim, index) → Tensor
たとえば、
import numpy as np
a = np.array([[1],[2],[3],[4],[5]])
b = torch.from_numpy(a)
indices = [ 1, 2, 0]
b[indices]

Output:
tensor([[2],
        [3],
        [1]])

参照先:
  • How to implement an equivalent of tf.gather in pytorch