permute関数(Pytorch)

11442 ワード

  • permute関数の役割はtensorを転置することである.
  • import torch
    import torch.nn as nn
    
    x = torch.randn(1, 2, 3, 4)
    print(x.size())      
    print(x.permute(2, 1, 0, 3).size())
    
    

    ランダムに1 X 2 X 3 X 4の4次元ベクトルを生成し、permute関数のパラメータは回転後のベクトル位置を表す.例えば、元のベクトル(1,2,3,4)、1の下付きは0、2の下付きは1、3の下付きは2、4の下付きは3である.x.permute(2,1,0,3)では,2は元の下表が2の数字を表す3が1位(つまりN),1は元の下表が1の数字を表す2が2位(つまりC),0は元の下表が0の数字を表す1が3位(つまりH),3は元の下表が3の数字を表す4が4位(つまりW)となる.結果は次のとおりです.
    torch.Size([1, 2, 3, 4])   #   tensor
    torch.Size([3, 2, 1, 4])   #    tensor
    
  • **torch.Transpose()は2 D転置のみを操作できます.**これはtorchではありません.transpose()は2次元ベクトルにのみ作用し、1回に2次元の回転しかできないことを意味し、複数の次元の回転が必要な場合はtranspose()を複数回呼び出す必要がある.例えば上記のtensore[1,2,3,4]をtensor[3,4,1,2]に転置し、transposeを使用するには以下の
  • を行う必要がある.
    x.transpose(0,2).transpose(1,3)
    
  • view()関数の役割のメモリは連続的でなければならない.操作数が連続的に記憶されていない場合は、操作前にcontiguous()を実行し、tensorをメモリに連続的に分布する形式にしなければならない.
  • tensor転置の実際の効果を直接見てみましょう.
  • import torch
    import torch.nn as nn
    import numpy as np
    
    y = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
    y_tensor = torch.tensor(y)
    y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
    print(y_tensor.size())
    print(y_tensor_trans.size())
    
    print(y_tensor)
    print(y_tensor_trans)
    print(y_tensor.view(1, 3, 2)) 
    
    
    torch.Size([1, 2, 3])
    torch.Size([3, 1, 2])
    tensor([[[1, 2, 3],
             [4, 5, 6]]])
    tensor([[[1, 4]],
    
            [[2, 5]],
    
            [[3, 6]]])
    tensor([[[1, 2],
             [3, 4],
             [5, 6]]])
    
  • view()関数は連続操作メモリであり、1、2、3、4、5、6を連続的に割り当て、上記は1 X 2 X 3を1 X 3 X 2に変更する.1 X 2 X 3になると、結果は以下の通り:
  • tensor([[[1, 2, 3],
             [4, 5, 6]]])