TensorFlow関数:tf.where()

12268 ワード

focal lossでクロスエントロピー関数を実現するには、label画像の異なるclassに対応する位置を取り出す必要があり、クエリーでtfを用いる必要がある.where関数は、次のように記録されます.

Tensorflow公式ドキュメントの紹介


Returns locations of true values in a boolean tensor.This operation returns the coordinates of true elements in input . The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input . Indices are output in row-major order.
For example:
# 'input' tensor is [[True, False]
#                    [True, False]]
# 'input' has two true values, so output has two coordinates.
# 'input' has rank of 2, so coordinates have two indices.
where(input) ==> [[0, 0],
                  [1, 0]]

# `input` tensor is [[[True, False]
#                     [True, False]]
#                    [[False, True]
#                     [False, True]]
#                    [[False, False]
#                     [False, True]]]
# 'input' has 5 true values, so output has 5 coordinates.
# 'input' has rank of 3, so coordinates have three indices.
where(input) ==> [[0, 0, 0],
                  [0, 1, 0],
                  [1, 0, 1],
                  [1, 1, 1],
                  [2, 1, 1]]

Args:
  • input: A Tensor of type bool.
  • name: A name for the operation (optional).

  • Returns:
  • A Tensor of type int64.

  • 用法2


    公式文書にはtf.where(input, name=None)の1つの用法しかなく、実際の応用では、a,bはいずれも寸法が一致するtf.where(input, a,b)であり、a中のtensor中のinputに対応する位置の要素値を不変にし、残りの要素を置換し、b中の対応する位置の要素値に置き換える役割を果たすもう1つの使用方法trueが発見された.例:
    import tensorflow as tf
    import numpy as np
    sess=tf.Session()
    
    a=np.array([[1,0,0],[0,1,1]])
    a1=np.array([[3,2,3],[4,5,6]])
    index = tf.where(a)
    
    print(sess.run(index))
    
    print(sess.run(tf.equal(a,1)))
    print(sess.run(tf.where(tf.equal(a,1),a1,1-a1)))
    
    print(sess.run(tf.equal(a,0)))
    print(sess.run(tf.where(tf.equal(a,0),a1,1-a1)))
    

    出力はそれぞれ
    [[0 0]
     [1 1]
     [1 2]]
    
    [[ True False False]
     [False  True  True]]
     
    [[ 3 -1 -2]
     [-3  5  6]]
     
    [[False  True  True]
     [ True False False]]
     
    [[-2  2  3]
     [ 4 -4 -5]]
    
    

    参照先:https://blog.csdn.net/a_a_ron/article/details/79048446