pytoch常用関数max,eq説明


maxはtensorの行または列の最大の値を探し出す:
各行の最大値を見つけます。

import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,1))
出力:
(tensor([1.,2.,3.])、tenssor([0,0,0])
各列の最大値を検索:

import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,0))
出力結果:
(tensor([3.])、tenssor([2])
Tensorの比較eq等しい:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))
出力結果:

tensor([[ 0],
[ 1],
[ 1]], dtype=torch.uint8)
sum()を使用して等しい個数を統計します。

import torch

outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data).cpu().sum())
出力結果:
tensor(2)
補足知識:PyTorch-toch.eq、toch. ne、touch.gt、toch.lt、touch.ge、touch.le
flyfish
touch.eq、toch.ne、toch.gt、touch.lt、touch.ge、touch.le
以上は全部簡略です
パラメータはinput、other、out=Noneです。
元素ごとに比較するinputとother
帰りはtoch.BoolTensorです。

import torch

a=torch.tensor([[1, 2], [3, 4]])
b=torch.tensor([[1, 2], [4, 3]])

print(torch.eq(a,b))#equals
# tensor([[ True, True],
#     [False, False]])

print(torch.ne(a,b))#not equal to
# tensor([[False, False],
#     [ True, True]])

print(torch.gt(a,b))#greater than
# tensor([[False, False],
#     [False, True]])

print(torch.lt(a,b))#less than
# tensor([[False, False],
#     [ True, False]])

print(torch.ge(a,b))#greater than or equal to
# tensor([[ True, True],
#     [False, True]])

print(torch.le(a,b))#less than or equal to
# tensor([[ True, True],
#     [ True, False]])
以上のpytouch常用関数maxは、eqでは、小編集が皆さんに提供している内容の全てを説明しています。参考にしていただければと思います。どうぞよろしくお願いします。