pytoch常用関数max,eq説明
maxはtensorの行または列の最大の値を探し出す:
各行の最大値を見つけます。
(tensor([1.,2.,3.])、tenssor([0,0,0])
各列の最大値を検索:
(tensor([3.])、tenssor([2])
Tensorの比較eq等しい:
tensor(2)
補足知識:PyTorch-toch.eq、toch. ne、touch.gt、toch.lt、touch.ge、touch.le
flyfish
touch.eq、toch.ne、toch.gt、touch.lt、touch.ge、touch.le
以上は全部簡略です
パラメータはinput、other、out=Noneです。
元素ごとに比較するinputとother
帰りはtoch.BoolTensorです。
各行の最大値を見つけます。
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,1))
出力:(tensor([1.,2.,3.])、tenssor([0,0,0])
各列の最大値を検索:
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
print(torch.max(outputs.data,0))
出力結果:(tensor([3.])、tenssor([2])
Tensorの比較eq等しい:
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))
出力結果:
tensor([[ 0],
[ 1],
[ 1]], dtype=torch.uint8)
sum()を使用して等しい個数を統計します。
import torch
outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data).cpu().sum())
出力結果:tensor(2)
補足知識:PyTorch-toch.eq、toch. ne、touch.gt、toch.lt、touch.ge、touch.le
flyfish
touch.eq、toch.ne、toch.gt、touch.lt、touch.ge、touch.le
以上は全部簡略です
パラメータはinput、other、out=Noneです。
元素ごとに比較するinputとother
帰りはtoch.BoolTensorです。
import torch
a=torch.tensor([[1, 2], [3, 4]])
b=torch.tensor([[1, 2], [4, 3]])
print(torch.eq(a,b))#equals
# tensor([[ True, True],
# [False, False]])
print(torch.ne(a,b))#not equal to
# tensor([[False, False],
# [ True, True]])
print(torch.gt(a,b))#greater than
# tensor([[False, False],
# [False, True]])
print(torch.lt(a,b))#less than
# tensor([[False, False],
# [ True, False]])
print(torch.ge(a,b))#greater than or equal to
# tensor([[ True, True],
# [False, True]])
print(torch.le(a,b))#less than or equal to
# tensor([[ True, True],
# [ True, False]])
以上のpytouch常用関数maxは、eqでは、小編集が皆さんに提供している内容の全てを説明しています。参考にしていただければと思います。どうぞよろしくお願いします。