torch.Argmax関数の説明


torch.Argmax()関数
Argmax関数:torch.argmax(input, dim=None, keepdim=False)は、指定する次元の最大値のシーケンス番号を返し、dimが与える定義は、the demention to reduceである.つまりdimという次元の、この次元の最大値になるindexです.1)dimの異なる値は異なる次元を表す.特にdim=0は2次元の列を表し、dim=1は2次元の行列の中で行を表す.広く言えば、1つのマトリクスが数次元であっても、例えば、1つのマトリクス次元が以下のように(d 0,d 1,...,dn−1)である場合、dim=0はd 0に対応する第1の次元を表し、dim=1は第2の次元に対応する第2の次元を表し、1回に類推する.2)dimの値がどういう意味なのかを知るのはまだだめで,関数の中でこのdimが与えられると何が起こるかを知る必要がある.
この2つを組み合わせるとdimの関数における役割がわかります.次に、上記の第2点を2つの例を挙げて説明する.
例1:torch.argmax()関数のdimは、この次元が消えることを示します.この消失はどういう意味ですか?公式英語解釈はdim(int)-the dimension to reduce.argmaxは最大値を得るシーケンス番号インデックスであることを知っています.次元(d 0,d 1)のマトリクスでは、各行の最大数の行の列番号を要求したいと思っています.最後に、次元(d 0,1)のマトリクスを得ます.この時、列は消えてしまいます.
したがって、各行の最大のカラム番号を要求するにはdim=1を指定し、カラムを必要とせず、ローのsizeを保持すればよいことを示します.各列の最大行標を求めたい場合はdim=0を指定し、行を必要としないことを示します.example 1
import torch
a=torch.tensor(
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ])
b=torch.argmax(a,dim=0)
print(b)
print(a.shape)

#結果:
tensor([1, 2, 0, 1])
torch.Size([3, 4])

dim=0次元で3、すなわちその3組のデータを比較し、各列の最大行標を求めるため[1,2,0,4]
example 2 3 D座標で
import torch
a=torch.tensor([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
 
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]])
b=torch.argmax(a,dim=0)
print(b)
print(a.shape)

"""
tensor([[0, 1, 0, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])
torch.Size([2, 3, 4])"""

#dim=0,         ,      [3*4]       ,          ,      [3*4]              

b=torch.argmax(a,dim=1)
“”“
tensor([[1, 2, 0, 1],
        [1, 2, 2, 1]])
torch.Size([2, 3, 4])”“”
#dim=1,         ,    :      [2*4];
"""[1, 5, 5, 2],
   [9, -6, 2, 8],
   [-3, 7, -9, 1];
       ,    [1,2,0,1];    [1,2,2,1];"""
b=torch.argmax(a,dim=2)
"""
tensor([[2, 0, 1],
        [1, 0, 2]])
"""
#dim=2,         ,    :      [2*3]
"""
   [1, 5, 5, 2],
   [9, -6, 2, 8],
   [-3, 7, -9, 1];
       
[2,0,1],       “”“