pytorch Tensor操作チートシート


概要

毎回調べてしまうpytorchのtensorの操作をまとめました
公式のドキュメンテーション以上の内容はありません

環境

pytorch 1.3.1

Tensorの基本操作

list, ndarrrayからTensorを生成する

a = [[1,2,3],[4,5,6]]
a_np = np.array(a)

# tensorにする
b = torch.tensor(a_list)
b = torch.tensor(a_np) # listからもndarrayからも変換可能
b = torch.from_numpy(a_np) # a_npとbはメモリが共有されるので, 片方を変更するともう片方も変わる


# データの型を指定できる dtype
>>> b = torch.tensor(a, dtype=float)
>>> b = torch.tensor(a).to(torch.float64) # 上に同じ
型一覧はここ. [公式](https://pytorch.org/docs/stable/tensors.html)


# 逆伝搬する場合はrequires_grad=Trueに(defaultはFalse). Variableは非推奨
>>> b = torch.tensor(a, requires_grad=True, dtype=float) # floatにする必要あり
>>> b = torch.tensor(a, dtype=float).requires_grad_() # 上と同じ
>>> b = torch.tensor(a, dtype=float).detach() # requires_grad = Falseになる


# GPUに乗せる device
>>> b = torch.tensor(a, device=0)
>>> b = torch.tensor(a, device=torch.device('cuda')) # 上と同じ
>>> b = torch.tensor(a).cuda(0) # これも同じ
>>> b.is_cuda
True
>>> b.device
device(type='cuda', index=0)

Tensorの初期化

# 0配列 torch.zeros()
>>> b = torch.zeros(3,3,5) # 配列のサイズを指定
>>> b = torch.zeros(3,3,5,requires_grad=True) # requires_gradも指定可能
>>> c = torch.zeros_like(b) # 引数のtensorと同じサイズの0配列を作る. z

# 1配列 torch.ones()
>>> torch.ones(3,3,5)
>>> c = torch.ones_like(b) # 引数のtensorと同じサイズの1配列を作る.

# 任意の数値の配列 torch.full()
>>> torch.full((2,3),fill_value=4) # full_likeもある
tensor([[4., 4., 4.],
        [4., 4., 4.]])

# 空配列 torch.empty()
>>> torch.empty(3,3,5)

# ランダム torch.rand()
>>> torch.rand(3,3,5) # [0 1)一様分布
>>> torch.randn(3,3,5) # N(0,1)正規分布
>>> torch.randint(0,5,(3,3,5)) # 0-4の整数から均一な確率でサンプル
>>> torch.randperm(10) # 0-9のrandom permutation
tensor([4, 6, 2, 8, 5, 9, 7, 3, 0, 1])

# arangeとlinspace
>>> torch.arange(1,4,0.5)
tensor([1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000])
>>> torch.linspace(1,4,5)
tensor([1.0000, 1.7500, 2.5000, 3.2500, 4.0000])

# 対角行列
>>> torch.eye(3)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> torch.eye(2,3)
tensor([[1., 0., 0.],
        [0., 1., 0.]])

# 1次元配列から対角行列を作る
>>> a = torch.rand(3)
>>> torch.diag(a)
tensor([[0.6432, 0.0000, 0.0000],
        [0.0000, 0.2017, 0.0000],
        [0.0000, 0.0000, 0.1462]])
>>> torch.diag(a,diagonal = 1) # オフセットをつけることもできる
tensor([[0.0000, 0.6432, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2017, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1462],
        [0.0000, 0.0000, 0.0000, 0.0000]])

# 逆に対角成分を取ってくる場合は
>>> torch.diagonal(a)

Tensorのサイズ

# size()
>>> b = torch.zeros(2,3,4)
>>> b.size()
torch.Size([2, 3, 4])
>>> b.ndim
3

# reshape()
>>> b = torch.zeros(2,3,4)
>>> b.reshape(6,4) # b.reshape(6,-1)と同じ
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
>>> b.view(6,4) # b.reshapeと結果は同じ

# torch.flatten(tensor) # reshape(-1)とおなじ
>>> a = torch.rand(2,2,2)
>>> torch.flatten(a)
tensor([0.6368, 0.7954, 0.7549, 0.4453, 0.6321, 0.8776, 0.5591, 0.9163])

# 次元を削減 squeeze()
>>> a = torch.zeros(2,1,3,1)
>>> b = torch.squeeze(a).size()
torch.Size([2, 3])
>>> c = torch.squeeze(a, 1).size() # dim = 1のみsqueeze
torch.Size([2, 3, 1])

# 次元を増やす unsqueeze()
>>> torch.unsqueeze(b,0).size() # unsqueeze(tensor,dim)
torch.Size([1, 2, 3])

# 転置
>>> a = torch.rand(2,3,4,5)
## torch.tensor.T
>>> a.T.size() #tensor.T
torch.Size([5, 4, 3, 2])

## torch.transpose(tensor, dim1, dim2)
>>> torch.transpose(a,0,1).size()
torch.Size([3, 2, 4, 5])

## torch.t(two_dim_tensor)
>>> a = torch.rand(2,3) # torch.tは2次元のtensorに対する処理
>>> torch.t(a).size()
torch.Size([3, 2])

# ブロードキャスト
## torch.broadcast_tensors(list of tensor)
>>> a = torch.rand(1,4)
>>> b = torch.rand(2,1)
>>> a_, b_ = torch.broadcast_tensors(a,b)
>>> b_
tensor([[0.1983, 0.1983, 0.1983, 0.1983],
        [0.7776, 0.7776, 0.7776, 0.7776]])
## torch.dotを使っても自動でbroadcastされるので, onesと掛け合わせて
>>> a*torch.ones_like(b) # a_と同じ
>>> b*torch.ones_like(a) # b_と同じ

# meshgrid



## broadcastの上の例はmeshgridでも実現できる.
>>> a_, b_ = torch.meshgrid(a.flatten(), b.flatten()) # 入力は1次元のtensor

# リピート repeat_interleave
>>> torch.repeat_interleave(b, 4, dim=1) # repeats = 4
tensor([[0.1983, 0.1983, 0.1983, 0.1983],
        [0.7776, 0.7776, 0.7776, 0.7776]])
>>> torch.repeat_interleave(b, 4, dim=0)
tensor([[0.1983],
        [0.1983],
        [0.1983],
        [0.1983],
        [0.7776],
        [0.7776],
        [0.7776],
        [0.7776]])

Tensorの要素のアクセス, スライス

# スライス: ndarrayと同じ

# flip
>>> a = torch.rand(2,2,2)
>>> a
tensor([[[0.9357, 0.3026],
         [0.1707, 0.6531]],

        [[0.9908, 0.1727],
         [0.1737, 0.9758]]])
>>> torch.flip(a,(0,)) # dim=0に沿ってflip
tensor([[[0.9908, 0.1727],
         [0.1737, 0.9758]],

        [[0.9357, 0.3026],
         [0.1707, 0.6531]]])
>>> torch.flip(a,(0,1,2))
tensor([[[0.9758, 0.1737],
         [0.1727, 0.9908]],

        [[0.6531, 0.1707],
         [0.3026, 0.9357]]])

# roll キャタピラみたいな感じで要素をずらしていく
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
>>> torch.roll(a,1,dims=0)
tensor([[6, 7],
        [0, 1],
        [2, 3],
        [4, 5]])
>>> torch.roll(a,2,dims=0)
tensor([[4, 5],
        [6, 7],
        [0, 1],
        [2, 3]])
>>> torch.roll(a,-1,dims=0)
tensor([[2, 3],
        [4, 5],
        [6, 7],
        [0, 1]])
>>> torch.roll(a,(1,1),dims=(0,1)) # 複数の次元に沿って同時にrollすることも可能
tensor([[7, 6],
        [1, 0],
        [3, 2],
        [5, 4]])

Tensorの結合分割

# torch.cat # np.hstack/np.vstackにあたる
>>> a = torch.zeros(2,3)
>>> b = torch.ones(2,3)
>>> torch.cat((a,b),0)
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> torch.cat((a,b),1)
tensor([[0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.]])

# torch.stack # 新たな次元方向に結合する
>>> a.size()
torch.Size([2, 3])
>>> torch.stack((a,a,a,a),0).size()
torch.Size([4, 2, 3])
>>> torch.stack((a,a,a,a),1).size()
torch.Size([2, 4, 3])
>>> torch.stack((a,a,a,a),2).size()
torch.Size([2, 3, 4])

要素の並び替え

# sort
>>> a = torch.rand(3,3)
>>> a
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]])

>>> sorted, idx = torch.sort(a) # dim = -1 (last dim) by default
>>> sorted
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]])
>>> idx # もとの行列のどのidxを取るか
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

>>> sorted, idx = torch.sort(a, descending = True) # descending order
>>> sorted
tensor([[0.8395, 0.5021, 0.2678],
        [0.9839, 0.8991, 0.1397],
        [0.6841, 0.6298, 0.6101]]

# argsort
>>> idx = torch.argsort(a)
>>> idx # torch.sort(a)の2番目の戻り値と同じ
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

# sortidxを使って並び替え gather(input, dim, idx)
>>> sorted, idx = torch.sort(a, -1) # dim = -1 (same as default value)
>>> torch.gather(a, -1, idx) sortと合わせてdim=-1
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]]) # sortedと同じものができる

## aとは別の配列bを, aを並び替えたのと同じ順に並び替える
>>> b = torch.rand_like(a) # 3*3 random
>>> torch.gather(b, 1, idx)

## sortした配列をもとに戻す
>>> invidx = torch.argsort(idx)
>>> torch.gather(sorted, -1, invidx)
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]]) # aと同じになる

要素のフィルタリング/検索

# where
>>> a = torch.rand(3,3)
>>> a
tensor([[0.0551, 0.0676, 0.7144],
        [0.7593, 0.6779, 0.0390],
        [0.9562, 0.4836, 0.1249]])

>>> torch.where(a > .5)
(tensor([0, 1, 1, 2]), tensor([2, 0, 1, 0])) # (0,2),(1,0),(1,1),(2,0)

>>> a[torch.where(a > .5)] # aの0.5より大きい要素を抽出
tensor([0.7144, 0.7593, 0.6779, 0.9562])

# torch.where(cond, if cond is true, if cond is false)
>>> torch.where(a > .5, torch.ones(3,3), torch.zeros(3,3))
tensor([[0., 0., 1.],
        [1., 1., 0.],
        [1., 0., 0.]])

>>> torch.where(a != 0)
>>> torch.nonzero(a, as_tuple=True) # 上と同じ

# masked_select # 真偽値のtensorをもとに, 対応する要素を抽出
>>> a > .5
tensor([[False, False,  True],
        [ True,  True, False],
        [ True, False, False]])
>>> torch.masked_select(a, a>.5) # masked_select(input, cond)

tensorの各要素への算術演算

>>> torch.abs(tensor) # 絶対値
>>> torch.ceil(tensor) # 切り上げ
>>> torch.floor(tensor) # 切り下げ -1.8 -> -2
>>> torch.round(tensor) # 四捨五入 -1.8 -> -1
>>> torch.tranc(tensor) # 小数点以下の切り捨て(0に近いほうへの丸め) -1.8 -> -1
>>> torch.flac(tensor) # 小数点以下. 1.3 -> 0.3, -4.6 -> -0.6
>>> torch.clamp(tensor,min,max) # min maxでのクリップ

>>> torch.log(tensor) # 自然対数
>>> torch.log10(tensor) # 常用対数
>>> torch.log2(tensor)
>>> torch.pow(tensor, power) # tensorのpower乗
>>> torch.pow(power, tensor) # powerのtensor乗
>>> torch.sigmoid(tensor)
>>> torch.sign(tensor) # 符号関数sgn. 0には0が返る

要素ごとの値の判定

>>> a = torch.tensor([1,float("Inf"),float("-Inf"),float("nan")])
>>> torch.isfinite(a)
tensor([ True, False, False, False])
>>> torch.isinf(a)
tensor([False,  True,  True, False])
>>> torch.isnan(a)
tensor([False, False, False,  True])

tensorの統計処理

>>> a
tensor([[0.5084, 0.7959, 0.1889, 0.1577, 0.8705],
        [0.0891, 0.0288, 0.2702, 0.0880, 0.6026]])

# min, max, argmin, argmax
## dimを指定しないと1次元の長い配列になる
>>> torch.min(a)
tensor(0.0288)
>>> torch.argmin(a)
tensor(6)

## dimを指定すると, val, idxが両方返ってくる
>>> val,idx=torch.min(a,1)
>>> val
tensor([0.1577, 0.0288])
>>> idx
tensor([3, 1])
>>> torch.argmin(a,1)
tensor([3, 1])

# 小さい順にk番目を取ってくる
>>> val,idx=torch.kthvalue(a, k=2, dim=1) # k=1ならtorch.min(a,1)と同じ
>>> val
tensor([0.1889, 0.0880])
>>> idx
tensor([2, 3])

# 上位k個を取ってくる
>>> val, idx = torch.topk(a, k=2, dim=1)
>>> val
tensor([[0.8705, 0.7959],
        [0.6026, 0.2702]])
>>> idx
tensor([[4, 1],
        [4, 2]])

# 統計量を計算する. APIはmin maxなどとだいたい一緒
>>> torch.mean(a)
tensor(0.3600)
>>> torch.mean(a,1)
tensor([0.5043, 0.2158])
>>> torch.median(a,1) # 戻り値 val, idx
>>> torch.mode(a,1) # 戻り値 val, idx
>>> torch.std(a, unbiased=True) # 標準偏差 不偏
>>> torch.std(a, unbiased=False) # 標準偏差 不偏でない
>>> torch.var(a) # 分散
>>> stddev, mean = torch.std_mean(a, unbiased=False) # 平均も標準偏差も
>>> stddev
tensor(0.2944)
>>> mean
tensor(0.3600)

>>> torch.sum(a,1)
>>> torch.prod(a,1)

# 累積和, 累積積
>>> torch.cumsum(a,1) # dim=1 累積和
tensor([[0.5084, 1.3044, 1.4933, 1.6510, 2.5215],
        [0.0891, 0.1179, 0.3882, 0.4762, 1.0788]])
>>> torch.cumprod(a,1) # dim=1 累積積
tensor([[5.0843e-01, 4.0468e-01, 7.6438e-02, 1.2058e-02, 1.0496e-02],
        [8.9140e-02, 2.5670e-03, 6.9373e-04, 6.1056e-05, 3.6792e-05]])

# log(sum(exp()))
>>> torch.logsumexp(a,1)
tensor([2.1571, 1.8487])

行列演算

# 行列積 matmul, mm
>>> a = torch.rand(3,2)
>>> b = torch.rand(2,4)
>>> torch.matmul(a,b).size()
torch.Size([3, 4])
>>> torch.mm(a,b).size() # mmはbroadcastしない
torch.Size([3, 4])

# 要素ごとの積 mul (*演算子と同じ)
>>> a = torch.rand(3,4)
>>> b = torch.rand(3,4)
>>> torch.mul(a,b).size()
torch.Size([3, 4])
## bloadcastもできる
>>> a = torch.rand(3,1)
>>> b = torch.rand(1,4)
>>> torch.mul(a,b).size()
torch.Size([3, 4])

# 内積 dot
torch.dot

# バッチごとの行列積 bmm
>>> a = torch.rand(10,3,2)
>>> b = torch.rand(10,2,4)
>>> torch.bmm(a,b).size()
torch.Size([10, 3, 4])

# 直積 cartesian_prod
>>> a = torch.rand(3)
>>> b = torch.rand(2)
>>> torch.cartesian_prod(a,b) # a,bともに1次元
tensor([[0.1174, 0.7799],
        [0.1174, 0.0018],
        [0.0949, 0.7799],
        [0.0949, 0.0018],
        [0.4312, 0.7799],
        [0.4312, 0.0018]])

>>> torch.trace(a) # trace
>>> torch.tril(a) # 下三角行列
>>> torch.triu(a) # 上三角行列
>>> torch.diag(torch.diagonal(a)) # 対角行列

>>> torch.eig(a) # 固有値固有ベクトル. backwardは未実装. torch.symeig()を利用する
>>> torch.inverse(a)
>>> torch.det(a)
>>> u, s, v = torch.svd(a)

# 便利な演算
>>> torch.addbmm(beta=1, add, alpha=1, a, b) # beta*add + alpha*sum(bmm(a,b),0)
>>> torch.baddbmm(beta=1, add, alpha=1, a, b) # beta*add + alpha*bmm(a,b)
>>> torch.addmm(beta=1, add, alpha=1, a, b) # beta*add + alpha*matmul(a,b)
>>> torch.addmv(beta=1, add, alpha=1, a, vec) # beta*add + alpha*matmul(a,vec)
>>> torch.addr(beta=1, add, alpha=1, vec_a, vec_b) # beta*add + alpha*outerprod(vec_a,vec_b)
>>> torch.chain_matmul(a,b,c) # matmul(matmul(a,b),c)

ノルム/距離

# ノルム
>>> torch.norm(a)
tensor(1.4707) # torch.norm(a.reshape(-1))と同じ
>>> torch.norm(a, dim=1, p='fro') # frobenius norm
tensor([1.3078, 0.6728])

# 距離
>>> a = torch.rand(2)
>>> b = torch.rand(2)
>>> torch.dist(a,b,p=2) # 2ノルム
tensor(0.6112)

# 各行同士の距離
>>> a = torch.rand(3,2)
>>> b = torch.rand(4,2)
>>> torch.cdist(a,b,p=2) # aの各行とbの各行の距離
tensor([[0.2806, 0.5389, 0.2644, 0.3649],
        [0.5951, 0.3302, 0.5304, 0.4894],
        [0.7638, 0.2005, 0.9625, 0.4779]])

そのほかの演算

einsum

文字列で数式を書くことで行列間の演算を定義できる.
基本
torch.einsum("ij,st->ijst",arr1, arr2)
f_ijst = arr1_i,j * arr2_s,t
for each i,j,s,t return f_ijst

torch.einsum("ij,st->ijs",arr1, arr2)
f_ijst = arr1_i,j * arr2_s,t
for each i,j,s return sum_t(f_ijst)

torch.einsum("ij,st->is",arr1, arr2)
f_ijst = arr1_i,j * arr2_s,t
for each i,s return sum_j,t(f_ijst)

>>> a
tensor([[1, 2],
        [3, 4]])
>>> b
tensor([2, 3])

>>> torch.einsum("ij->i",a)
tensor([3, 7])

>>> torch.einsum("ij,s->ijs",a,b)
tensor([[[ 2,  3],
         [ 4,  6]],

        [[ 6,  9],
         [ 8, 12]]])
# f_ijs = a_ij * b_s

>>> torch.einsum("ij,s->ijs",a,b)
tensor([[[ 2,  3],
         [ 4,  6]],

        [[ 6,  9],
         [ 8, 12]]])
# f_ijs = a_ij * b_s

>>> torch.einsum("ij,s->is",a,b)
tensor([[ 6,  9],
        [14, 21]])
# f_ijs = a_ij * b_s
# for each i,s sum_j(f_ijs)

>>> torch.einsum("ij,i->i",a,b)
tensor([[ 6, 21]])
# f_ij = a_ij * b_i
# for each i sum_j(f_ij)

# torch.bmm
>>> torch.einsum('ijk,ikl->ijl',a,b) # 次元iは共通, 出力の形はijl

gather

# gather(input, dim, index) 以下のような処理をします
    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
>>> a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
>>> b = torch.eye(3).to(int)
>>> a
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
>>> torch.gather(a,1,b)
tensor([[2, 1, 1], # idx = 1,0,0
        [4, 5, 4], # idx = 0,1,0
        [7, 7, 8]])# idx = 0,0,1

unique/連続している要素だけunique/出現回数をカウント

# unique
>>> torch.unique(torch.tensor([2,1,3,4,2,3]),sorted=False))
tensor([4,3,1,2])
>>> torch.unique(torch.tensor([2,1,3,4,2,3]),sorted=True)) #default True
tensor([1,2,3,4])

# uniqしたtensorをもとに戻すインデックスを得ることができる
>>> out, inv_idx = torch.unique(torch.tensor([2,1,3,4,2,3]),return_inverse=True))
>>> out
tensor([1, 2, 3, 4])
>>> inv_idx
tensor([1, 0, 2, 3, 1, 2])
>>> out[inv_idx]
tensor([2,1,3,4,2,3])

# 出現回数をカウントできる
>>> out, cnt = torch.unique(torch.tensor([2,1,3,4,2,3]), return_counts=True)
>>> out
tensor([1, 2, 3, 4])
>>> cnt
tensor([1, 2, 2, 1])

# unique_consecutive 連続する重複する値を削除
>>> torch.unique_consecutive(torch.tensor([1,2,2,2,3,3,1,2]))
tensor([1, 2, 3, 1, 2]) # torch.uniqueならtensor([1,2,3])になる

線形補完

# 線形補完 torch.lerp(start, end, weight)
>>> torch.lerp(torch.tensor([1,2,3],dtype=float), torch.tensor([2,6,5],dtype=float), 0.25)
tensor([1.2500, 3.0000, 3.5000], dtype=torch.float64)