Pytorchではlabelをonehot符号化の2つの方式に変える

3509 ワード

PytorchはTensorFlowのようにグーグルの大手がメンテナンスをしていないため、多くの機能は高度なパッケージではありません.例えばtfがありません.one_hot関数.本編ではmini batchのlabelベクトルを[batch size,class numbers]として形状を変えたone hot符号化の2つの方法を紹介し,
tensor.scatter_ tensor.index_select

scatter_の使用onehot符号化を取得


CSDNでこの関数の使い方を探している人はみな公式の紹介が分からないと信じています.だから、私は他の場所のように公式のチュートリアルを運ぶことはありません.私も長い間考えていましたが、関数の声明はやはり見なければなりません.
tensor.scatter_(dim, index, src) 
dim :  。 。 0  sum(tensor.shape)-1
index :  src tensor 。index shape src , src shape 。
src :  tensor 

まず、別のブログcopyから例を見てみましょうが、もっと詳しく紹介します.いい話だと思ったらメッセージを残して励ましてください.
>>> x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

dimは0であり、最初の次元を頼りにしていることに注意してください.indexは2次元配列です.[0,1,2,0,0][2,0,0,1,2]ではtensorを上書きする位置は10個あり、それぞれ
[0,0];[1,1];[2,2];[0,3];[0,4]
[2,0];[0,1];[0,2];[1,3];[2,4]

dimはindexがindexの値をどの軸の値とするかを指定します.他の軸は0からmax shape-1に変化するだけです.例えばdimが0である場合、indexの値は座標の最初の位置の値として、別の位置は0から4に変換されます.この10の位置がカバーされているかどうかを検証することができます.10の位置の最初の軸はindexの数字で、2番目の数字はindexの列数で、0から4です.上書きする場所がありますが、どの値で上書きしますか?私たちのindexの次元はsrcと同じであることを忘れないでください.indexでどの位置の座標を選択するかは、srcが適用される位置の値に置き換えられます.たとえばtensorの[0,0]の値を置き換えると、indexの[0,0]は0行目の0列目に対応する位置であり、tensorの値をsrc 0行目の0列目の値に置き換えます.検証してみてください.
次の状況を見てみましょう.dimが1なら.
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
 

dimが1である場合、indexの値は座標の2番目の位置の値として、1番目の位置の値は0から1まで変化するべきである.だから代わられる位置には
[0,2];[1,3]

一方、[0,2]の位置に入力する値は1.23、[1,3]の値は1.23です.(放送メカニズムは1.23というスカラーをshapeが(2,1)に拡張した)
はい、関数の使い方がわかりました.この関数を用いてlabelをonehot符号化する方法を見てみましょう.
まずbatch sizeが8のlabelを想定する.10種類あるので、labelの数字は0から9のはずです.
import torch as t
import numpy as np

batch_size = 8
class_num = 10
label = np.random.randint(0,class_num,size=(batch_size,1))
label = t.LongTensor(label)

label,shapeは(8,1),必ず2次元を得た.(8,)次の内容であれば誤報になります.
y_one_hot = t.zeros(batch_size,class_num).scatter_(1,label,1)
print(y_one_hot)

'''
tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])
'''

片付ける.次の方法を見てみましょう.

tensorを使用index_selectはonehot符号化を取得する


まずindexを見てみましょうselectの使い方.
tensor.index_select( dim, index, out=None)
dim :  tensor 
index :  。 dim tensor index 。

例を見ないで、直接方法を見て、これを例にします.
    ones = torch.sparse.torch.eye(class_num)
    return ones.index_select(0,label)


ここのlabelは1次元のベクトルで、2次元ではありません.indexは1次元でなければならない単位行列を作成したので、サイズは[class_num,class_num]です.dimは0であり,これは行に従ってtensorのベクトルをとると考えられる.具体的にどの行を取るかは、labelの値です.このとき、なぜこの2行のコードがone hot符号化を実現できるのか、私たちも知っているだろう.labelが[1,3,0]の場合、4種類あります.では、[0,1,0,0][0,0,0,1][1,0,0]を得ます.
△良質なブログ、転載は出典を明記してください.