Pytorch学習(二十二)soft labelのクロスエントロピーlossの実現
11384 ワード
いつも言う
参照リンク:
まず情報エントロピー,クロスエントロピー,相対エントロピーを理解する
まずクロスエントロピーの定義を探します:1)情報エントロピー:符号化スキームが完璧な場合、最短平均符号化長はいくらですか.2)クロスエントロピー:符号化スキームが必ずしも完璧ではない場合(確率分布の推定が必ずしも正確ではないため)、平均符号化長はいくらであるか.平均符号化長=最短平均符号化長+1増分3)相対エントロピー:符号化スキームが必ずしも完璧ではない場合、平均符号化長の最小値に対する増加値.(すなわち、上記の増分)(すなわち、相対エントロピーは情報利得であり、KL分散である)
作者:張一山リンク:https://www.zhihu.com/question/41252833/answer/140950659出典:著作権は作者の所有であることを知っている.商業転載は著者に連絡して許可を得てください.非商業転載は出典を明記してください.
じょうほうエントロピー
(1)については何も言わないが,ある事柄の確率が既知であれば,直接H(X)=−Σi=1 np(x i)l o g(p(x i))H(X)=−sum_{i=1}^{n}p(x_i)log(p(x_i)) H(X)=−∑i=1np(xi)log(p(xi)). この表現は,イベントXに対してx 1,⋅,x i,⋅,x n x_のみである.1,\cdot , x_i,\cdot, x_nx 1,⋅,xi,⋅,xn種の場合,x i x_i xiが発生する確率はp(x i)p(x_i)p(xi)であり、直接−p l o g(p)−plog(p)−plog(p)−plog(p)OK?少し理由を言うと、主にl o g(p(x i))log(p(x_i))log(p(xi))の代表がx i x_であるixiイベントのエントロピー,すなわち情報量.x 1 x_1 x 1およびx 2 x_2 x 2は独立して同分布し、2つが同時に発生すると、確率はp(x 1)p(x 2)p(x_1)p(x_2)p(x 1)p(x 2)であり、情報量は累積されるべきである.すなわち、H(p(p(p(x 1)p(x 2)=H(p(x 1)+H(p(p(x 2))H(p(x_1)p(x_2))=H(p(p(x_1)+H(p(p(x 1)+H(p(p(x 1)p(x 2)=H(p(p(x 1)+H(p(x 1)+H(p(x 2)=H(p(x 1)+x 2)+H(p(x 2)+)であるため、H(x)H(x)H(x)は対数関数である.底は2,e,10 2,e,10 2,e,10を取ることができ、問題は大きくありません.最後に、p p p pは一般的に小数なので、情報である以上正数が合っているので、
を付けて席にいるのは文句ないでしょう~.まとめ:−p l o g(p)−plog(p)−plog(p)はエントロピーである.(もちろん実際には、累加して、このように書くのはただ覚えやすいためです)そうたいエントロピー
相対エントロピーはKL分散であり、よく考えてみると、p p pの実際の分布を知ることができないので、q qを予測し、このq qの分布とp p p pができるだけ近いことを望んで、KL分散を測定とすることができます.画像の違い、L 2損失、PSNR、SSIMなど、興味深いことを考えてみてください.これらは測定方法で、違いを表しています.確率分布の違いを説明しますか?直接減算することはできないでしょう.これを使ってもいいです(もちろん、Wasserstein distanceのような他の説明はたくさんありますが、KL分散の代わりに、WGANを引用しています).
ここで、KL分散の定義は、DKL(p∣∣q)=Σi=1 np(x i)l o g(p(x i)q(x i)D_{KL}(p||q)=\sum_{i=1}^{n}p(x_i)log(frac{p(x_i)}{q(x_i)}DKL(pQの分布はできるだけPに近いようにしてください.一度訓練して、Q分布でPを表す場合、追加の情報が必要なのはDKL(p∣∣q)D_である.{KL}(p||q)DKL(pだから分類はq(x)q(x)q(x)q(x)をできるだけp(x)p(x)p(x)に近づけることである.
クロスエントロピー
D K L(p∣∣q)=−H(p(x))+[−Σi=1 np(x i)l o g(q(x i))]D_{KL}(p||q)=−H(p(x)+[−sum_{i=1}^{n}p(x_i)log(q(x_i))]DKL(p−H(p(x)−H(p(x))−H(p(x))−H(p(x)))は変わらないので、トレーニングセットの自然な法則を表しましょう(例えば、猫である確率がp(x=猫)であるp(x 1)p(x 1)p(x 1)であり、犬であるp(x=犬)p(x=犬)p(x=犬)p(x 2)p(x 2)である).クロスエントロピーを見てみましょう
H ( p , q ) = − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))H(p,q)=−i=1Σnp(xi)log(q(xi))は、推定された確率q q qで符号化され、必要な符号長を表す.
対照的に:エントロピーとクロスエントロピーの形式(記憶を容易にするためだけ):H(X)=−Σi=1 np(x i)l o g(p(x i)))H(X)=−sum_{i=1}^{n}p(x_i)log(p(x_i)) H(X)=−i=1∑np(xi)log(p(xi)) H ( p , q ) = − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) H(p,q) = -\sum_{i=1}^{n}p(x_i)log(q(x_i))H(p,q)=−i=1Σnp(xi)log(q(xi))**はいずれも−pl o g(p)−plog(p)−plog(p)**形式であり,p p pの確率が分からない場合は推定されたq q q qで置き換えてl o g logに詰め込む.
分類におけるクロスエントロピー
クロスオーバーエントロピーを熟練して書き出してみると、問題なくq(x i)q(x_i)q(xi)は、この画像I Iを入力すると、ネットワークの出力の確率(softmaxを通過した後)H(p,q)=−Σi=1 np(x i)l o g(q(x i))H(p,q)=−sum_{i=1}^{n}p(x_i)log(q(x_i))H(p,q)=−i=1Σnp(xi)log(q(xi))この画像が猫である場合、p(I=猫)=p(x 1)=1 p(I=猫)=p(x_1)=p(I=猫)=p(x_1)=p(I=猫)=p(x 1)=p(x 1)=p(x 1)=1)であり、その他の確率は0である.
それは実は簡単なので、一般的なクロスエントロピーの計算は以下の通りです:l o s=−Σi=1 Ky i l o g(y i^)loss=−sum_{i=1}^{K}y_{i}log(hat{y_i})loss=−i=1ΣK yi log(yi^)は、I I Iを入力したときに、この画像と実分布の損失を示す.△クロスエントロピーであり、log種の各クラスの確率を予測すればよい.通常の分類であれば、各ピクチャについて損失値は、−l o g(正しいカテゴリが予測された確率)−log(正しいカテゴリが予測された確率)−log(正しいカテゴリが予測された確率)−log(正しいカテゴリが予測された確率)であり、実際のラベルはhard labelが一般的に採用されているため、p(この図は正しいカテゴリに属する)=1 p(この図は正しいカテゴリに属する)=1 p(この図は正しいカテゴリに属する)=1 p(この図は正しいカテゴリに属する)=1であり、p(図が他のカテゴリに属する)=0 p(図が他のカテゴリに属する)=0 p(図が他のカテゴリに属する)=0.もちろん正規書きの場合:l o s s=−1 NΣj=1 NΣi=1 ny j i l o g(y j i^)loss=−frac{1}{N}sum_{j=1}^{N}\sum_{i=1}^{n}y_{ji}log(hat{y_{ji}})loss=−N 1 j=1ΣNi=1Σnyji log(yji^)ここで、N Nはbatchのピクチャ数である.
SoftCrossEntropy
実は式に従えば良いのですが、l o s=−1 NΣj=1 NΣi=1 ny j i l o g(y j i^)loss=−frac{1}{N}sum_{j=1}^{N}\sum_{i=1}^{n}y_{ji}log(hat{y_{ji}})loss=−N 1 j=1ΣNi=1Σnyji log(yji^)ただ、一般的には画像ごとにK個のカテゴリの値が加算されているように見えますが、実際には1個の値しかありません.ソフトなら本当にK K個の値が加算されます.
すると、次のようなものがありました.import torch
import torch.nn.functional as F
def SoftCrossEntropy(inputs, target, reduction='sum'):
log_likelihood = -F.log_softmax(inputs, dim=1)
batch = inputs.shape[0]
if reduction == 'average':
loss = torch.sum(torch.mul(log_likelihood, target)) / batch
else:
loss = torch.sum(torch.mul(log_likelihood, target))
return loss
注意点:targetはsoftmax正規化された値、すなわち真実確率y j i y_を表すji yji.2分類であれば、i=1またはi=2 i=1またはi=2 i=1またはi=2である.一方inputsはネットワークの直接出力(ボリューム層やfcの出力でソフトmaxは通っていない)なので、−l o g(q)−log(q)−log(q)でしょうから、ここでは-F.log_softmax
を使います.もちろん,最後に−pl o g(q)−plog(q)−plog(q)はtargetに直接乗せればよい.
普通の分類はこのようにして、SSDなどであれば、彼は実はfeatureの点(pixel levelの小さな枠に対応)ごとに分類します.例えば、inputs
はM*C
であり、そのうちM=N*out_dim
である.すなわち、ネットワーク出力がout_dim
次元(例えばssd,2分類、ネットワーク出力8000以上の予選鉱)、N
がbatchsizeであると仮定すると、直接前の2つの次元が直接合併してもよいかどうかは、各点(N*out_dim
のような複数のfeature点(または小枠))が分類される.各ボックスの分類に相当します.
追加:二分類と多分類の違い
最も簡単な分類は二分類で、ここで言う二分類は、このクラスではないかを指します.二分類であれば、実は最後のニューロンでいいので、BCEWithLogitsLoss()
またはnn.Sigmoid
の後にBCELoss
を加えます.このときソフトmaxで正規化する必要はありません.AとBの2つのカテゴリがある場合は、最後に2つのニューロンが必要です.多分類はLogSoftmax()
の後にNLLLoss()
が続く.>>> # 2D loss example (used, for example, with image inputs)
>>> N, C = 5, 4
>>> loss = nn.NLLLoss()
>>> # input is of size N x C x height x width
>>> data = torch.randn(N, 16, 10, 10)
>>> conv = nn.Conv2d(16, C, (3, 3))
>>> m = nn.LogSoftmax(dim=1)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
>>> output = loss(m(conv(data)), target)
>>> output.backward()
このうちネットワークの出力の最大値(LogSoftMaxより前で良い)が、この画像のカテゴリです.
import torch
import torch.nn.functional as F
def SoftCrossEntropy(inputs, target, reduction='sum'):
log_likelihood = -F.log_softmax(inputs, dim=1)
batch = inputs.shape[0]
if reduction == 'average':
loss = torch.sum(torch.mul(log_likelihood, target)) / batch
else:
loss = torch.sum(torch.mul(log_likelihood, target))
return loss
最も簡単な分類は二分類で、ここで言う二分類は、このクラスではないかを指します.二分類であれば、実は最後のニューロンでいいので、
BCEWithLogitsLoss()
またはnn.Sigmoid
の後にBCELoss
を加えます.このときソフトmaxで正規化する必要はありません.AとBの2つのカテゴリがある場合は、最後に2つのニューロンが必要です.多分類はLogSoftmax()
の後にNLLLoss()
が続く.>>> # 2D loss example (used, for example, with image inputs)
>>> N, C = 5, 4
>>> loss = nn.NLLLoss()
>>> # input is of size N x C x height x width
>>> data = torch.randn(N, 16, 10, 10)
>>> conv = nn.Conv2d(16, C, (3, 3))
>>> m = nn.LogSoftmax(dim=1)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
>>> output = loss(m(conv(data)), target)
>>> output.backward()
このうちネットワークの出力の最大値(LogSoftMaxより前で良い)が、この画像のカテゴリです.