TensorFlow one-hot符号化tf.one_hotの基本的な使い方とインスタンスコード

2915 ワード

一、環境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
 
二、公式説明
入力したindicesをone-hot符号化形式に変換
indicesで指定した位置の値はone_です.valueパラメータ値、その他の位置はoff_を取りますvalueパラメータ値
パラメータone_valueとパラメータoff_valueのデータ型は同じでなければなりません.dtypeを指定した場合は、すべてそのデータ型でなければなりません.
パラメータone_valueは指定されていません.デフォルトは1で、タイプは指定されたdtypeです.
パラメータoff_の場合valueは指定されていません.デフォルトは0で、タイプは指定されたdtypeです.
入力パラメータindicesの次数がNである場合、出力データの次数N+1;新しい軸はパラメータaxisの次元に追加されます(axisを指定しない場合はデフォルトで一番後ろの次元に追加されます).
indicesがスカラーの場合、出力結果の形状は長さdepthのベクトルである
indicesがfeaturesの長さのベクトルである場合、出力結果の形状は次のとおりです.
features x depth if axis == -1
depth x features if axis == 0

indicesが[batch,features]形状のマトリクスの場合、出力結果の形状は:
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0

パラメータdtypeが指定されていない場合、このメソッドはデフォルトでデータフォーマットとパラメータon_を仮定します.valueまたはoff_valueは同じで、dtype、on_valueとoff_valueが指定されていない場合、dtypeのデフォルトはtfです.float32
注意:出力結果が非数値形式の場合、例えば:tf.string、tf.boolなどであれば、on_valueとoff_valueはすべて設定する必要があります
https://tensorflow.google.cn/api_docs/python/tf/one_hot
tf.one_hot(
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)

パラメータ:
indices:インデックスの値を持つテンソル
depth:独熱符号化次元のスカラーを指定する
on_value:インデックスindices[j]=iの位置に埋め込まれたスカラー、デフォルトは1
off_value:インデックスindices[j]!=iすべての場所に埋め込まれたスカラー、デフォルトは0
axis:塗りつぶした軸、デフォルトは-1(一番奥の新しい軸)
dtype:出力テンソルのデータフォーマット
name:オプションパラメータ、アクションの名前
戻り値:
ユニホットコーディングone-hotテンソル
 
三、実例
(1)一次元リスト形式の整数カテゴリラベルをone-hotカテゴリラベル形式に変換する
>>> import tensorflow as tf
>>> labels = [0,1,2]
>>> one_hot_labels = tf.one_hot(indices=labels,depth=3, on_value=1, off_value=0, axis=-1, dtype=tf.int32, name="one-hot")
>>> one_hot_labels

>>> with tf.Session() as sess:
...     print(sess.run(one_hot_labels))
...
[[1 0 0]
 [0 1 0]
 [0 0 1]]

(2)2 Dリスト形式の整数カテゴリラベルをone-hotカテゴリラベル形式に変換
>>> import tensorflow as tf
>>> labels = [[0,1],[2,3]]
>>> labels
[[0, 1], [2, 3]]
>>> one_hot_labels = tf.one_hot(indices=labels,depth=3, on_value=1.0, off_value=0.0, axis=-1)
>>> one_hot_labels

>>> with tf.Session() as sess:
...     print(sess.run(one_hot_labels))
...
[[[1. 0. 0.]
  [0. 1. 0.]]

 [[0. 0. 1.]
  [0. 0. 0.]]]

四、注意事項
パラメータ「dtype」を使用して出力テンソルのデータフォーマットを定義する場合は、必ずパラメータ「on_value」と「off_value」のデータフォーマットが対応しなければなりません.そうしないと、エラーが発生します.
例:dtype=tf.float 32,on_value=1, off_value=0、すなわち前者は浮動小数点型を指し、後者は整形であり、エラーを報告する.
“TypeError: dtype of on_value does not match dtype parameter ”