pytorchのtensorをsort, 同じ順で別のtensorを並べ替え, sortしたtensorを元に戻す.


環境

python3
pytorch 1.3.0

準備

>>> import torch
>>> a = torch.rand(3,3)
>>> a
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]])

torch.sortで並べ替える

# 指定した次元の方向にsortする
# sortする次元はdimで指定
>>> sorted, idx = torch.sort(a, dim = -1) # デフォルトはdim = -1 (-1は最後の次元. 3次元なら3と書くのと同じ)

# 戻り値は2つ
>>> sorted # 並び替えられたtensor
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]])

>>> idx # もとの行列をどういう順番に並び替えたのかの情報(後述のtorch.argsortの戻り値と同じ)
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

# 例えば今回の例では, aとsortedの関係は, idxを使って
# sorted[0] = a[0,1], a[0,2], a[0,0]
# sorted[1] = a[1,0], a[1,2], a[1,1]
# sorted[2] = a[2,1], a[2,0], a[2,2]

# 降順に並び替える. descending
>>> sorted, idx = torch.sort(a, descending = True)
>>> sorted
tensor([[0.8395, 0.5021, 0.2678],
        [0.9839, 0.8991, 0.1397],
        [0.6841, 0.6298, 0.6101]]

torch.argsort: どういう風に並べ替えたのかの情報を得る

# torch.sortの2つめの戻り値と同じ
>>> idx = torch.argsort(a)
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

torch.gather: argsortの結果を使って, 別のtensorを同じように並べ替える

# tensor a と同じサイズのtensor bを用意
>>> b = torch.rand_like(a) # 3*3 random

# まずtensor aを並べ替えます
>>> sorted, idx = torch.sort(a, dim = -1)

# torch.gather(input, dim, idx)
# tensor a を並べ替えたときの情報idxを用いて, aと同じ順でbを並べかえ
>>> torch.gather(b, -1, idx) # sort時と並べ替える方向は合わせてdim=-1

# つまりtorch.gatherでaをidxを用いて並べ替えた結果はsortedと同じ
>>> torch.gather(a, -1, idx) 
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]]) 

sortした配列をもとに戻す

# idxをargsortした結果(idxのidx)をgatherに突っ込む
>>> invidx = torch.argsort(idx) # idxのidx
>>> torch.gather(sorted, -1, invidx)
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]]) # aと同じになる