tf.one_hot()は独熱符号化を行う
2485 ワード
tf.one_hot()は独熱符号化を行う
まず、独熱符号化(one-hot encoding)とは何かを説明する必要があるに違いない.例を挙げると、例えば5種類の分類問題があり、データ(xi,yi)があり、カテゴリyiには5種類の値(5種類の分類問題であるため)があるため、yjが第1クラスである場合、その独熱符号化は[1,0,0,0,0]であり、第2クラスである場合、独熱符号化は[0,1,0,0,0]であり、すなわち、そのカテゴリが存在する数の位置のみを1とし、その他はすべて0である.この符号化方式は多分類問題,特に損失関数がクロスエントロピー関数である場合によく用いられる.次に、TensorFlowに搭載されているデータの独熱符号化関数
tf.one_hot()
について説明します.まず、APIマニュアルを貼り付けます.one_hot(
indices,
depth,
on_value=None,
off_value=None,
axis=None,
dtype=None,
name=None
)
indices
、depth
を指定する必要があります.ここで、depth
は符号化深度であり、on_value
、off_value
は符号化後の開閉値に相当します.先ほど説明したX値とY値と同様に、dtype
と同じタイプ(dtype
を指定した場合)が必要です.axis
は符号化の軸を指定します.ここでは、小さな例を示します.var = tf.one_hot(indices=[1, 2, 3], depth=4, axis=0)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
a = sess.run(var)
print(a)
しゅつりょく
[[ 0. 0. 0.]
[ 1. 0. 0.]
[ 0. 1. 0.]
[ 0. 0. 1.]]
axis=0
軸に符号化するため、depth
は4であり、それに相当して列方向に広がる.axis
を1に変更すると、次のようになります.[[ 0. 1. 0. 0.]
[ 0. 0. 1. 0.]
[ 0. 0. 0. 1.]]