RuntimeError: Expected object of scalar type Float but got scalar type Int for argument #2 'other'

798 ワード

異なるタイプは等しくなく、比較できません.
import torch
a=torch.tensor([1],dtype=torch.float)
b=torch.tensor([1],dtype=torch.int)
#tensor  1           1  
#          ,      
if a==b:
    print((a==b))
    print((a==b).long().dtype)
else:
    print("not equal")

結果はエラー
    if a==b:
RuntimeError: Expected object of scalar type Float but got scalar type Int for argument #2 'other'

デフォルトの1とtensorでは、さまざまなタイプが等しい
import torch
a=torch.tensor([1],dtype=torch.float)
#b=torch.tensor([1],dtype=torch.int)
b=1
#tensor  1              
#          ,      
if a==b:
    print((a==b))
    print((a==b).long().dtype)
else:
    print("not equal")

結果
tensor([1], dtype=torch.uint8)
torch.int64