PyTorchのtoch.catの使い方

2411 ワード

1.字面理解:
touch.catは2枚の枚数をつなぎ合わせて、catはconcatnateという意味です。つまりつなぎ合わせて、繋がっています。
2.例理解

>>> import torch
>>> A=torch.ones(2,3) #2x3   (  )                   
>>> A
tensor([[ 1., 1., 1.],
    [ 1., 1., 1.]])
>>> B=2*torch.ones(4,3)#4x3   (  )                  
>>> B
tensor([[ 2., 2., 2.],
    [ 2., 2., 2.],
    [ 2., 2., 2.],
    [ 2., 2., 2.]])
>>> C=torch.cat((A,B),0)#   0( )  
>>> C
tensor([[ 1., 1., 1.],
     [ 1., 1., 1.],
     [ 2., 2., 2.],
     [ 2., 2., 2.],
     [ 2., 2., 2.],
     [ 2., 2., 2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4   (  )
>>> C=torch.cat((A,D),1)#   1( )  
>>> C
tensor([[ 1., 1., 1., 2., 2., 2., 2.],
    [ 1., 1., 1., 2., 2., 2., 2.]])
>>> C.size()
torch.Size([2, 7])
上に2つのテンソルAとBがあります。それぞれ2行3列、4行3列です。彼らは2次元のテンソルです。二次元しかないので、このようにtouch.catでつなぎ合わせる時に二種類のつなぎ合わせがあります。行ごとにつづり合わせて連結します。いわゆる次元0と次元1.
C=touch.cat((A,B)、0)は次元0でAとBをつなぎ合わせるという意味です。つまり縦につなぎ合わせて、AとBの下に行きます。このとき注意が必要です。列の数は一致していなければなりません。つまり、次元1の値は同じです。ここは全部3列です。列は揃えることができます。スティッチングしたCの第0次元は2次元の0値と、つまり2+4=6です。
C=touch.cat((A,B),1)は次元1(列)でAとBをつなぎ合わせるという意味です。つまり横につなぎ合わせて、A左Bの右側です。この時注意しなければなりません。行数は必ず一致します。つまり、次元0の値は同じで、ここは全部2行です。つなぎ合わせたCの第1次元は2次元の1値と、すなわち3+4=7です。
2次元の例から、touch.cat((A,B),dim)を使う場合は、スペル次元dimの数値が異なる以外の次元の数値は同じで、配置できます。
3.例
画像の深さ学習では、3チャンネルのRGBカラー画像と単チャンネルの階調図が一般的である。テンソルsizeはcxhxwであり、チャネル数x画像高さx画像幅である。touch.catで2枚の画像をつなぎ合わせる時、一般的に画像の大きさが一致することが要求されますが、チャンネル数が一致しない、つまりhとwは同じで、cは違ってもいいです。もちろん実際には3つのつなぎ合わせがありますが、他の2つはあまり見られないようです。例えば、経典ネットワーク構造:U-Net

中は4回のtouch.catを使って、その中のcopy and crop操作はtouch.catを通じて実現しました。アップサンプリングによってオリジナル画像hとwを2倍にし、左から直接コピーしたのと同じh,wの画像をつなぎ合わせて見ることができます。このようにすれば、元の構造情報を効果的に利用することができる。
4.まとめ
touch.cat((A,B),dim)を使う場合は、スペル次元dimの値が異なる以外の次元の数値は同じであれば、配置できます。
補足知識:PyTorchのconcatつまりtoch.catの例
余計なことを言わないで、コードを見てください。

import torch
a = torch.ones([1,2])
b = torch.ones([1,2])
torch.cat([a,b],1)
 1 1 1 1
[torch.FloatTensor of size 1x4]
以上のPyTorchのtoch.catは小编で皆さんに提供した内容の全部です。参考にしていただければと思います。どうぞよろしくお愿いします。