sigmoid_cross_entropy_with_logitsクロスエントロピー損失の概要とテスト
2608 ワード
def
この関数の役割はsigmoid関数によって活性化された後のクロスエントロピーを計算することである.
def sigmoid_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=None):
計算式:
簡潔に記述するために,x=logits(例えば図),z=targets(分類結果)を規定した.
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
簡略化する
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))
x<0の場合、すなわちxが小さな負数の場合、e^{−x}が無限大になると、logを取った後に無限大になる.誤報を招く
RuntimeWarning: overflow encountered in exp
exp(-x)の計算時にオーバーフローを回避するために、以下の形式で表します.
x - x * z + log(1 + exp(-x))
= log(exp(x)) - x * z + log(1 + exp(-x))
= - x * z + log(1 + exp(x))
しかし、実際にはx>0に対して依然としてオーバーフローであり、総合的に考慮すると、
max(x,0)−x∗z+log(1+exp(−abs(x)))
これもtensorflowで採用されている式です.次のコードをテストします.
code
import tensorflow as tf
import numpy as np
def sigmod(x):
return 1.0/(1+np.exp(-x))
logits=np.array([[1.,-810.,20.],[11.,12.,14.],[12.,21.,23.]])
labels=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
y_predict=sigmod(logits)
loss_1=logits*(1-labels)+np.log(1+np.exp(-logits))
print('
',loss_1)
print('------------------')
print('tensorflow
')
with tf.Session() as sess:
print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
print('------------------')
print('
')
loss_2=np.maximum(logits,0)-logits*labels+np.log(1+np.exp(-np.abs(logits)))
print('
',loss_2)
out:
[[ 3.13261688e-01 inf 2.00000000e+01]
[ 1.10000167e+01 6.14419348e-06 1.40000008e+01]
[ 1.20000061e+01 2.10000000e+01 1.02618802e-10]]
------------------
tensorflow
[[ 3.13261688e-01 0.00000000e+00 2.00000000e+01]
[ 1.10000167e+01 6.14419348e-06 1.40000008e+01]
[ 1.20000061e+01 2.10000000e+01 1.02618796e-10]]
------------------
[[ 3.13261688e-01 0.00000000e+00 2.00000000e+01]
[ 1.10000167e+01 6.14419348e-06 1.40000008e+01]
[ 1.20000061e+01 2.10000000e+01 1.02618802e-10]]
ref https://blog.csdn.net/m0_37393514/article/details/81393819 https://www.cnblogs.com/cloud-ken/p/7435421.html