Tensorflow演算子で発生したピット(1は1ではありません)

14403 ワード

一言で言えば、Tensorは!===をリロードしていないので、tf.equal()の代わりに==を使用するべきです.
問題を発見する
この2,3日kerasで小さなものを作って、metricsをカスタマイズしてモデル関数を評価する必要があります.
from keras import backend as K

THRESHOLD = 0.5

@staticmethod
def error_rate(y_true, y_pred):
    pred = K.tf.where(y_pred >= THRESHOLD, K.tf.ones_like(y_true), K.tf.zeros_like(y_true))
    acc = K.tf.where(pred == true, K.ones_like(y_true), K.zeros_like(y_true))
    return K.mean(1-acc)

その結果、error_rateはどのように訓練しても1.0であることが分かった.
コードロジックの問題ですか?
そこで私はこのようなコードを書きました
	import numpy as np
	a1 = np.array([0.9, 0.1, 1.])
    a2 = np.array([1., 0., 1.])
    a1 = np.where(a1 >= 0.5, np.ones_like(a1), np.zeros_like(a1))
    ans = np.sum(np.where(a1 == a2, np.ones_like(a1), np.zeros_like(a1)))
    print('ans={}'.format(ans))# ans=3.0

出力ans=3.0、論理に誤りがなく、問題を探し続けます
1いったい1に等しくないの?
テストでnumpyを使っていたことに気づいたのはtensorflow(パソコンのkerasのバックエンド)の問題だったのだろうか.
テストコードは次のとおりです.
import tensorflow as tf


def main():
    a1 = tf.Variable([1., 0., 1.])
    a2 = tf.Variable([1., 0., 1.])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ans = sess.run(tf.reduce_sum(tf.where(a1 == a2, tf.ones_like(a1), tf.zeros_like(a1))))
        print('ans={}'.format(ans))# ans=0.0

if __name__ == '__main__':
    main()

上記のコードはans=3.0を出力すべきであるが、実際には以下のように出力される.
ans=0.0

結果はans=0.0でした
これは特性ですか?
さんざん苦労して解決策を見つけた
def main():
    a1 = tf.Variable([1., 0., 1.])
    a2 = tf.Variable([1., 0., 1.])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ans = sess.run(tf.reduce_sum(tf.where(tf.equal(a1, a2), tf.ones_like(a1), tf.zeros_like(a1))))
        print('ans={}'.format(ans))# ans=3.0

出力ans=3.0、問題解決
なぜ__eq__は重荷重されていないのですか?
実はある
tensorflow.Tensor

  def __eq__(self, other):
    # Necessary to support Python's collection membership operators
    return id(self) == id(other)

tensorflowを使用する場合に最も一般的な「Python」s collection」は以下の通りです.
a1 = tf.placeholder(shape=(2, 2), dtype=tf.uint8)
feed_dict = {a1: np.ones((2, 2))}

リロード_eq__2つのtensor間でelement-wiseの演算を行うと、上記のコードが失効する可能性があります.
参考:Issue:Equality operator not overloaded