sigmoid_cross_entropy_with_logitsクロスエントロピー損失の概要とテスト

2608 ワード

def


この関数の役割はsigmoid関数によって活性化された後のクロスエントロピーを計算することである.
def sigmoid_cross_entropy_with_logits(_sentinel=None,  labels=None, logits=None,  name=None):

計算式:


簡潔に記述するために,x=logits(例えば図),z=targets(分類結果)を規定した.
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))

簡略化する
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

x<0の場合、すなわちxが小さな負数の場合、e^{−x}が無限大になると、logを取った後に無限大になる.誤報を招く
RuntimeWarning: overflow encountered in exp

exp(-x)の計算時にオーバーフローを回避するために、以下の形式で表します.
 x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

しかし、実際にはx>0に対して依然としてオーバーフローであり、総合的に考慮すると、
max(x,0)−x∗z+log(1+exp(−abs(x)))

これもtensorflowで採用されている式です.次のコードをテストします.

code


import tensorflow as tf
import numpy as np

def sigmod(x):
    return 1.0/(1+np.exp(-x))

logits=np.array([[1.,-810.,20.],[11.,12.,14.],[12.,21.,23.]])
labels=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])

y_predict=sigmod(logits)

loss_1=logits*(1-labels)+np.log(1+np.exp(-logits))
print(' 
',loss_1) print('------------------') print('tensorflow
') with tf.Session() as sess: print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits))) print('------------------') print('
') loss_2=np.maximum(logits,0)-logits*labels+np.log(1+np.exp(-np.abs(logits))) print('
',loss_2)

out:
 
 [[  3.13261688e-01              inf   2.00000000e+01]
 [  1.10000167e+01   6.14419348e-06   1.40000008e+01]
 [  1.20000061e+01   2.10000000e+01   1.02618802e-10]]
------------------
tensorflow 

[[  3.13261688e-01   0.00000000e+00   2.00000000e+01]
 [  1.10000167e+01   6.14419348e-06   1.40000008e+01]
 [  1.20000061e+01   2.10000000e+01   1.02618796e-10]]
------------------
 

 
 [[  3.13261688e-01   0.00000000e+00   2.00000000e+01]
 [  1.10000167e+01   6.14419348e-06   1.40000008e+01]
 [  1.20000061e+01   2.10000000e+01   1.02618802e-10]]

ref https://blog.csdn.net/m0_37393514/article/details/81393819 https://www.cnblogs.com/cloud-ken/p/7435421.html