Tensorflow演算子で発生したピット(1は1ではありません)
14403 ワード
一言で言えば、Tensorは
問題を発見する
この2,3日kerasで小さなものを作って、metricsをカスタマイズしてモデル関数を評価する必要があります.
その結果、
コードロジックの問題ですか?
そこで私はこのようなコードを書きました
出力ans=3.0、論理に誤りがなく、問題を探し続けます
1いったい1に等しくないの?
テストでnumpyを使っていたことに気づいたのはtensorflow(パソコンのkerasのバックエンド)の問題だったのだろうか.
テストコードは次のとおりです.
上記のコードは
結果は
これは特性ですか?
さんざん苦労して解決策を見つけた
出力
なぜ
実はある
tensorflowを使用する場合に最も一般的な「Python」s collection」は以下の通りです.
リロード_eq__2つのtensor間でelement-wiseの演算を行うと、上記のコードが失効する可能性があります.
参考:Issue:Equality operator not overloaded
!=
と==
をリロードしていないので、tf.equal()
の代わりに==
を使用するべきです.問題を発見する
この2,3日kerasで小さなものを作って、metricsをカスタマイズしてモデル関数を評価する必要があります.
from keras import backend as K
THRESHOLD = 0.5
@staticmethod
def error_rate(y_true, y_pred):
pred = K.tf.where(y_pred >= THRESHOLD, K.tf.ones_like(y_true), K.tf.zeros_like(y_true))
acc = K.tf.where(pred == true, K.ones_like(y_true), K.zeros_like(y_true))
return K.mean(1-acc)
その結果、
error_rate
はどのように訓練しても1.0
であることが分かった.コードロジックの問題ですか?
そこで私はこのようなコードを書きました
import numpy as np
a1 = np.array([0.9, 0.1, 1.])
a2 = np.array([1., 0., 1.])
a1 = np.where(a1 >= 0.5, np.ones_like(a1), np.zeros_like(a1))
ans = np.sum(np.where(a1 == a2, np.ones_like(a1), np.zeros_like(a1)))
print('ans={}'.format(ans))# ans=3.0
出力ans=3.0、論理に誤りがなく、問題を探し続けます
1いったい1に等しくないの?
テストでnumpyを使っていたことに気づいたのはtensorflow(パソコンのkerasのバックエンド)の問題だったのだろうか.
テストコードは次のとおりです.
import tensorflow as tf
def main():
a1 = tf.Variable([1., 0., 1.])
a2 = tf.Variable([1., 0., 1.])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ans = sess.run(tf.reduce_sum(tf.where(a1 == a2, tf.ones_like(a1), tf.zeros_like(a1))))
print('ans={}'.format(ans))# ans=0.0
if __name__ == '__main__':
main()
上記のコードは
ans=3.0
を出力すべきであるが、実際には以下のように出力される.ans=0.0
結果は
ans=0.0
でしたこれは特性ですか?
さんざん苦労して解決策を見つけた
def main():
a1 = tf.Variable([1., 0., 1.])
a2 = tf.Variable([1., 0., 1.])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ans = sess.run(tf.reduce_sum(tf.where(tf.equal(a1, a2), tf.ones_like(a1), tf.zeros_like(a1))))
print('ans={}'.format(ans))# ans=3.0
出力
ans=3.0
、問題解決なぜ
__eq__
は重荷重されていないのですか?実はある
tensorflow.Tensor
def __eq__(self, other):
# Necessary to support Python's collection membership operators
return id(self) == id(other)
tensorflowを使用する場合に最も一般的な「Python」s collection」は以下の通りです.
a1 = tf.placeholder(shape=(2, 2), dtype=tf.uint8)
feed_dict = {a1: np.ones((2, 2))}
リロード_eq__2つのtensor間でelement-wiseの演算を行うと、上記のコードが失効する可能性があります.
参考:Issue:Equality operator not overloaded