Pytorchはtfを実現する.gather()
976 ワード
1.tfを実現する.gather
pytorchでは、tfを実現する.gatherは簡単でselectを使うだけです.
たとえば、
Output:
参照先: How to implement an equivalent of tf.gather in pytorch
pytorchでは、tfを実現する.gatherは簡単でselectを使うだけです.
select(dim, index) → Tensor
たとえば、
import numpy as np
a = np.array([[1],[2],[3],[4],[5]])
b = torch.from_numpy(a)
indices = [ 1, 2, 0]
b[indices]
Output:
tensor([[2],
[3],
[1]])
参照先: