permute関数(Pytorch)
import torch
import torch.nn as nn
x = torch.randn(1, 2, 3, 4)
print(x.size())
print(x.permute(2, 1, 0, 3).size())
ランダムに1 X 2 X 3 X 4の4次元ベクトルを生成し、permute関数のパラメータは回転後のベクトル位置を表す.例えば、元のベクトル(1,2,3,4)、1の下付きは0、2の下付きは1、3の下付きは2、4の下付きは3である.x.permute(2,1,0,3)では,2は元の下表が2の数字を表す3が1位(つまりN),1は元の下表が1の数字を表す2が2位(つまりC),0は元の下表が0の数字を表す1が3位(つまりH),3は元の下表が3の数字を表す4が4位(つまりW)となる.結果は次のとおりです.
torch.Size([1, 2, 3, 4]) # tensor
torch.Size([3, 2, 1, 4]) # tensor
x.transpose(0,2).transpose(1,3)
import torch
import torch.nn as nn
import numpy as np
y = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())
print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2))
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],
[4, 5, 6]]])
tensor([[[1, 4]],
[[2, 5]],
[[3, 6]]])
tensor([[[1, 2],
[3, 4],
[5, 6]]])
tensor([[[1, 2, 3],
[4, 5, 6]]])