pytorchで指定した位置要素を取得する

327 ワード

このコードの適用シーンは、あるbatchのsentence、あるpadding操作を経て、各文の実際の最後の単語を取得した場合です.
A = torch.Tensor([[[2, 3, 1], [1, 4, 0], [1, 0, 0]], [[2, 2, 0], [2, 0, 0], [3, 1, 4]]])
print(A.size())

B = torch.Tensor([[3, 2, 1], [2, 1, 3]]).long()
print(B.size())
B = B.view(2, 3, -1)
B = B - 1

C = torch.gather(A, 2, B)
print(C)