【pytorch】tensor 型の確認


pytorchの使い方の確認のため書く。

is_tensor

torch.is_tensor(obj)のobjがtensorならTrueを返す。

>>> import torch
>>> a=torch.tensor([[0, 1], [2, 3], [4, 5]])
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5]])
>>> torch.is_tensor(a)
True
>>> torch.is_tensor(1)
False

is_storage

torch.is_storage(obj)のobjがPyTorch storage objectならTrueを返す。
PyTorch storage objectが何を示しているかはまだ勉強中。

>>> torch.is_storage(a)
False
>>> torch.is_storage(1)
False

is_complex

torch.is_complex(obj)のobjが複素数ならTrueを返す。
注意
以下のコードでcfloatが原因でエラーになるときがある。それはpytorhのバージョンを上げることで解決できる。
https://pytorch.org/でpytorchのバージョンを上げられる。

>>> b=torch.rand((3,3), dtype=torch.cfloat)
>>> b
tensor([[0.1955+0.3707j, 0.5087+0.8311j, 0.2987+0.4621j],
        [0.5064+0.7313j, 0.9797+0.8714j, 0.0260+0.7468j],
        [0.1468+0.5747j, 0.9827+0.3058j, 0.8373+0.9638j]])
>>> torch.is_complex(a)
False
>>> torch.is_complex(b)
True
>>> torch.is_complex(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: is_complex(): argument 'input' (position 1) must be Tensor, not int

is_floating_point

torch.is_floating_point(obj)のobjがfloatならTrueを返す。
objがcfloatならFalseである。

>>> c=torch.rand((3,3), dtype=torch.float)
>>> c
tensor([[0.7839, 0.8118, 0.8087],
        [0.4022, 0.0625, 0.7638],
        [0.3203, 0.0413, 0.3117]])
>>> torch.is_floating_point(a)
False
>>> torch.is_floating_point(b)
False
>>> torch.is_floating_point(c)
True
>>> torch.is_floating_point(1.0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: is_floating_point(): argument 'input' (position 1) must be Tensor, not float